Example #1
# Imports assumed for this excerpt. DDP may come from apex.parallel in the
# original project; Fp16Optimizer, Fp32Optimizer, WarmupMultiStepLR,
# AverageMeter and sync_workers are project-local helpers imported from the
# surrounding package.
import logging
import os
import time
from itertools import cycle

import numpy as np
import torch
import torch.optim
from torch.nn.parallel import DistributedDataParallel as DDP

class Seq2SeqTrainer:
    """
    Trainer for sequence-to-sequence models: training, validation and
    checkpointing.
    """
    def __init__(self,
                 model,
                 criterion,
                 opt_config,
                 scheduler_config,
                 print_freq=10,
                 save_freq=1000,
                 grad_clip=float('inf'),
                 batch_first=False,
                 save_info=None,
                 save_path='.',
                 train_iterations=0,
                 checkpoint_filename='checkpoint%s.pth',
                 keep_checkpoints=5,
                 math='fp32',
                 cuda=True,
                 distributed=False,
                 intra_epoch_eval=0,
                 iter_size=1,
                 translator=None,
                 verbose=False):
        """
        Constructor for the Seq2SeqTrainer.

        :param model: model to train
        :param criterion: criterion (loss function)
        :param opt_config: dictionary with options for the optimizer
        :param scheduler_config: dictionary with options for the learning rate
            scheduler
        :param print_freq: prints short summary every 'print_freq' iterations
        :param save_freq: saves checkpoint every 'save_freq' iterations
        :param grad_clip: coefficient for gradient clipping
        :param batch_first: if True the model uses (batch, seq, feature)
            tensors, if False the model uses (seq, batch, feature)
        :param save_info: dict with additional state stored in each checkpoint
        :param save_path: path to the directory for checkpoints
        :param train_iterations: total number of training iterations to execute
        :param checkpoint_filename: name of files with checkpoints
        :param keep_checkpoints: max number of checkpoints to keep
        :param math: arithmetic precision, either 'fp16' or 'fp32'
        :param cuda: if True use cuda, if False train on CPU
        :param distributed: if True run distributed training
        :param intra_epoch_eval: number of additional eval runs within each
            training epoch
        :param iter_size: number of iterations between weight updates
        :param translator: instance of Translator, runs inference on test set
        :param verbose: enables verbose logging
        """
        super(Seq2SeqTrainer, self).__init__()
        self.model = model
        self.criterion = criterion
        self.epoch = 0
        self.save_info = save_info if save_info is not None else {}
        self.save_path = save_path
        self.save_freq = save_freq
        self.save_counter = 0
        self.checkpoint_filename = checkpoint_filename
        self.checkpoint_counter = cycle(range(keep_checkpoints))
        self.opt_config = opt_config
        self.cuda = cuda
        self.distributed = distributed
        self.print_freq = print_freq
        self.batch_first = batch_first
        self.verbose = verbose
        self.loss = None
        self.translator = translator
        self.intra_epoch_eval = intra_epoch_eval
        self.iter_size = iter_size

        if cuda:
            self.model = self.model.cuda()
            self.criterion = self.criterion.cuda()

        if math == 'fp16':
            self.model = self.model.half()

        if distributed:
            self.model = DDP(self.model)

        if math == 'fp16':
            self.fp_optimizer = Fp16Optimizer(self.model, grad_clip)
            params = self.fp_optimizer.fp32_params
        elif math == 'fp32':
            self.fp_optimizer = Fp32Optimizer(self.model, grad_clip)
            params = self.model.parameters()
        else:
            raise ValueError(f'Unsupported math mode: {math}')

        opt_name = opt_config.pop('optimizer')
        self.optimizer = torch.optim.__dict__[opt_name](params, **opt_config)
        logging.info(f'Using optimizer: {self.optimizer}')

        self.scheduler = WarmupMultiStepLR(self.optimizer, train_iterations,
                                           **scheduler_config)

    def iterate(self, src, tgt, update=True, training=True):
        """
        Performs one iteration of the training/validation.

        :param src: batch of examples from the source language
        :param tgt: batch of examples from the target language
        :param update: if True, the optimizer updates the weights
        :param training: if True, runs the backward pass and optimizer step
        """
        src, src_length = src
        tgt, tgt_length = tgt
        src_length = torch.LongTensor(src_length)
        tgt_length = torch.LongTensor(tgt_length)

        num_toks = {}
        num_toks['tgt'] = int(sum(tgt_length - 1))
        num_toks['src'] = int(sum(src_length))

        if self.cuda:
            src = src.cuda()
            src_length = src_length.cuda()
            tgt = tgt.cuda()

        if self.batch_first:
            output = self.model(src, src_length, tgt[:, :-1])
            tgt_labels = tgt[:, 1:]
            T, B = output.size(1), output.size(0)
        else:
            output = self.model(src, src_length, tgt[:-1])
            tgt_labels = tgt[1:]
            T, B = output.size(0), output.size(1)

        loss = self.criterion(output.view(T * B, -1),
                              tgt_labels.contiguous().view(-1))

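        # Keep the raw (summed) batch loss for logging; the scaled loss below
        # makes gradients accumulated over `iter_size` micro-batches
        # approximate one large-batch update.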
        loss_per_batch = loss.item()
        loss /= (B * self.iter_size)

        if training:
            self.fp_optimizer.step(loss, self.optimizer, self.scheduler,
                                   update)

        loss_per_token = loss_per_batch / num_toks['tgt']
        loss_per_sentence = loss_per_batch / B

        return loss_per_token, loss_per_sentence, num_toks

    def feed_data(self, data_loader, training=True):
        """
        Runs training or validation on batches from data_loader.

        :param data_loader: data loader
        :param training: if True runs training else runs validation
        """
        if training:
            assert self.optimizer is not None
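            # Schedule `intra_epoch_eval` BLEU evaluations at evenly spaced
            # weight-update boundaries within the epoch.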
            eval_fractions = np.linspace(0, 1, self.intra_epoch_eval + 2)[1:-1]
            iters_with_update = len(data_loader) // self.iter_size
            eval_iters = (eval_fractions * iters_with_update).astype(int)
            eval_iters = eval_iters * self.iter_size
            eval_iters = set(eval_iters)

        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses_per_token = AverageMeter(skip_first=False)
        losses_per_sentence = AverageMeter(skip_first=False)

        tot_tok_time = AverageMeter()
        src_tok_time = AverageMeter()
        tgt_tok_time = AverageMeter()

        batch_size = data_loader.batch_size

        end = time.time()
        for i, (src, tgt) in enumerate(data_loader):
            self.save_counter += 1
            # measure data loading time
            data_time.update(time.time() - end)

            update = False
            if i % self.iter_size == self.iter_size - 1:
                update = True

            # do a train/evaluate iteration
            stats = self.iterate(src, tgt, update, training=training)
            loss_per_token, loss_per_sentence, num_toks = stats

            # measure accuracy and record loss
            losses_per_token.update(loss_per_token, num_toks['tgt'])
            losses_per_sentence.update(loss_per_sentence, batch_size)

            # measure elapsed time
            elapsed = time.time() - end
            batch_time.update(elapsed)
            src_tok_time.update(num_toks['src'] / elapsed)
            tgt_tok_time.update(num_toks['tgt'] / elapsed)
            tot_num_toks = num_toks['tgt'] + num_toks['src']
            tot_tok_time.update(tot_num_toks / elapsed)
            self.loss = losses_per_token.avg

            if training and i in eval_iters:
                test_bleu, _ = self.translator.run(calc_bleu=True,
                                                   epoch=self.epoch,
                                                   iteration=i)

                log = []
                log += [f'TRAIN [{self.epoch}][{i}/{len(data_loader)}]']
                log += [f'BLEU: {test_bleu:.2f}']
                log = '\t'.join(log)
                logging.info(log)

                self.model.train()
                self.preallocate(data_loader, training=True)

            if i % self.print_freq == 0:
                phase = 'TRAIN' if training else 'VALIDATION'
                log = []
                log += [f'{phase} [{self.epoch}][{i}/{len(data_loader)}]']
                log += [f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})']
                log += [f'Data {data_time.val:.2e} ({data_time.avg:.2e})']
                log += [
                    f'Tok/s {tot_tok_time.val:.0f} ({tot_tok_time.avg:.0f})'
                ]
                if self.verbose:
                    log += [
                        f'Src tok/s {src_tok_time.val:.0f} ({src_tok_time.avg:.0f})'
                    ]
                    log += [
                        f'Tgt tok/s {tgt_tok_time.val:.0f} ({tgt_tok_time.avg:.0f})'
                    ]
                    log += [
                        f'Loss/sentence {losses_per_sentence.val:.1f} ({losses_per_sentence.avg:.1f})'
                    ]
                log += [
                    f'Loss/tok {losses_per_token.val:.4f} ({losses_per_token.avg:.4f})'
                ]
                if training:
                    lr = self.optimizer.param_groups[0]['lr']
                    log += [f'LR {lr:.3e}']
                log = '\t'.join(log)
                logging.info(log)

            save_chkpt = (self.save_counter %
                          self.save_freq) == (self.save_freq - 1)
            if training and save_chkpt:
                self.save_counter = 0
                self.save_info['iteration'] = i
                identifier = next(self.checkpoint_counter, -1)
                if identifier != -1:
                    with sync_workers() as rank:
                        if rank == 0:
                            self.save(identifier=identifier)

            end = time.time()

        tot_tok_time.reduce('sum')
        losses_per_token.reduce('mean')

        return losses_per_token.avg, tot_tok_time.avg

    def preallocate(self, data_loader, training):
        """
        Generates maximum sequence length batch and runs forward and backward
        pass without updating model parameters.

        :param data_loader: data loader
        :param training: if True preallocates memory for backward pass
        """
        batch_size = data_loader.batch_size
        max_len = data_loader.dataset.max_len

        src_length = [max_len] * batch_size
        tgt_length = [max_len] * batch_size

        if self.batch_first:
            shape = (batch_size, max_len)
        else:
            shape = (max_len, batch_size)

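        # Fill the dummy batches with an arbitrary in-vocabulary token id; the
        # values only matter for producing maximum-size tensors.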
        src = torch.full(shape, 4, dtype=torch.int64)
        tgt = torch.full(shape, 4, dtype=torch.int64)
        src = src, src_length
        tgt = tgt, tgt_length
        self.iterate(src, tgt, update=False, training=training)
        self.model.zero_grad()

    def optimize(self, data_loader):
        """
        Sets model in training mode, preallocates memory and runs training on
        data provided by data_loader.

        :param data_loader: data loader
        """
        torch.set_grad_enabled(True)
        self.model.train()
        torch.cuda.empty_cache()
        self.preallocate(data_loader, training=True)
        output = self.feed_data(data_loader, training=True)
        self.model.zero_grad()
        torch.cuda.empty_cache()
        return output

    def evaluate(self, data_loader):
        """
        Sets model in eval mode, disables gradients, preallocates memory and
        runs validation on data provided by data_loader.

        :param data_loader: data loader
        """
        torch.set_grad_enabled(False)
        self.model.eval()
        torch.cuda.empty_cache()
        self.preallocate(data_loader, training=False)
        output = self.feed_data(data_loader, training=False)
        self.model.zero_grad()
        torch.cuda.empty_cache()
        return output

    def load(self, filename):
        """
        Loads checkpoint from filename.

        :param filename: path to the checkpoint file
        """
        if os.path.isfile(filename):
            checkpoint = torch.load(filename, map_location={'cuda:0': 'cpu'})
            if self.distributed:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            self.fp_optimizer.initialize_model(self.model)
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.scheduler.load_state_dict(checkpoint['scheduler'])
            self.epoch = checkpoint['epoch']
            self.loss = checkpoint['loss']
            logging.info(f'Loaded checkpoint {filename} (epoch {self.epoch})')
        else:
            logging.error(f'Invalid checkpoint: {filename}')

    def save(self, identifier=None, is_best=False, save_all=False):
        """
        Stores checkpoint to a file.

        :param identifier: identifier for periodic checkpoint
        :param is_best: if True stores checkpoint to 'model_best.pth'
        :param save_all: if True stores checkpoint after completed training
            epoch
        """
        def write_checkpoint(state, filename):
            filename = os.path.join(self.save_path, filename)
            logging.info(f'Saving model to {filename}')
            torch.save(state, filename)

        if self.distributed:
            model_state = self.model.module.state_dict()
        else:
            model_state = self.model.state_dict()

        state = {
            'epoch': self.epoch,
            'state_dict': model_state,
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'loss': getattr(self, 'loss', None),
        }
        state.update(self.save_info)

        if identifier is not None:
            filename = self.checkpoint_filename % identifier
            write_checkpoint(state, filename)

        if is_best:
            filename = 'model_best.pth'
            write_checkpoint(state, filename)

        if save_all:
            filename = f'checkpoint_epoch_{self.epoch:03d}.pth'
            write_checkpoint(state, filename)
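
A minimal usage sketch for the trainer above (hypothetical wiring: `model`,
`criterion`, the data loaders, `num_epochs` and the exact
`opt_config`/`scheduler_config` keys all depend on the surrounding project):

# Hypothetical driver code, not part of the original file.
trainer = Seq2SeqTrainer(model=model,
                         criterion=criterion,
                         opt_config={'optimizer': 'Adam', 'lr': 1e-3},
                         scheduler_config={'warmup_steps': 200},
                         train_iterations=len(train_loader) * num_epochs,
                         math='fp32',
                         cuda=True)

best_loss = float('inf')
for epoch in range(num_epochs):
    trainer.epoch = epoch
    train_loss, train_tok_per_sec = trainer.optimize(train_loader)
    val_loss, _ = trainer.evaluate(val_loader)
    trainer.save(is_best=val_loss < best_loss, save_all=True)
    best_loss = min(best_loss, val_loss)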
Example #2
# Imports assumed for this excerpt; the project-local helpers (models,
# common.utils, data_functions, loss_functions, TBLogger, init_dllogger,
# init_distributed, last_checkpoint, load_checkpoint, save_checkpoint,
# adjust_learning_rate, apply_ema_decay, reduce_tensor, validate, parse_args)
# come from the surrounding repository.
import argparse
import copy
import json
import os
import time
from collections import OrderedDict

import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

import dllogger as DLLogger
from apex import amp
from apex.optimizers import FusedAdam, FusedLAMB
from apex.parallel import DistributedDataParallel as DDP

def main():
    parser = argparse.ArgumentParser(description='PyTorch FastPitch Training',
                                     allow_abbrev=False)
    parser = parse_args(parser)
    args, _ = parser.parse_known_args()

    if 'LOCAL_RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        local_rank = int(os.environ['LOCAL_RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
    else:
        local_rank = args.rank
        world_size = args.world_size
    distributed_run = world_size > 1

    torch.manual_seed(args.seed + local_rank)
    np.random.seed(args.seed + local_rank)

    if local_rank == 0:
        if not os.path.exists(args.output):
            os.makedirs(args.output)
        init_dllogger(args.log_file)
    else:
        init_dllogger(dummy=True)

    for k, v in vars(args).items():
        DLLogger.log(step="PARAMETER", data={k: v})

    parser = models.parse_model_args('FastPitch', parser)
    args, unk_args = parser.parse_known_args()
    if len(unk_args) > 0:
        raise ValueError(f'Invalid options {unk_args}')

    torch.backends.cudnn.enabled = args.cudnn_enabled
    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    if distributed_run:
        init_distributed(args, world_size, local_rank, args.group_name)

    device = torch.device('cuda' if args.cuda else 'cpu')
    model_config = models.get_model_config('FastPitch', args)
    model = models.get_model('FastPitch', model_config, device)

    # Store pitch mean/std as params to translate from Hz during inference
    with open(args.pitch_mean_std_file, 'r') as f:
        stats = json.load(f)
    model.pitch_mean[0] = stats['mean']
    model.pitch_std[0] = stats['std']

    kw = dict(lr=args.learning_rate,
              betas=(0.9, 0.98),
              eps=1e-9,
              weight_decay=args.weight_decay)
    if args.optimizer == 'adam':
        optimizer = FusedAdam(model.parameters(), **kw)
    elif args.optimizer == 'lamb':
        optimizer = FusedLAMB(model.parameters(), **kw)
    else:
        raise ValueError(f'Unknown optimizer: {args.optimizer}')

    if args.amp_run:
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    if args.ema_decay > 0:
        ema_model = copy.deepcopy(model)
    else:
        ema_model = None

    if distributed_run:
        model = DDP(model)

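    # `start_epoch` is wrapped in a list so that load_checkpoint can update it
    # in place when resuming.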
    start_epoch = [1]

    assert args.checkpoint_path is None or args.checkpoint_resume is False, (
        "Specify a single checkpoint source")
    if args.checkpoint_path is not None:
        ch_fpath = args.checkpoint_path
    elif args.checkpoint_resume:
        ch_fpath = last_checkpoint(args.output)
    else:
        ch_fpath = None

    if ch_fpath is not None:
        load_checkpoint(local_rank, model, ema_model, optimizer, start_epoch,
                        model_config, args.amp_run, ch_fpath, world_size)

    start_epoch = start_epoch[0]

    criterion = loss_functions.get_loss_function(
        'FastPitch',
        dur_predictor_loss_scale=args.dur_predictor_loss_scale,
        pitch_predictor_loss_scale=args.pitch_predictor_loss_scale)

    collate_fn = data_functions.get_collate_function('FastPitch')
    trainset = data_functions.get_data_loader('FastPitch', args.dataset_path,
                                              args.training_files, args)
    valset = data_functions.get_data_loader('FastPitch', args.dataset_path,
                                            args.validation_files, args)
    if distributed_run:
        train_sampler, shuffle = DistributedSampler(trainset), False
    else:
        train_sampler, shuffle = None, True

    train_loader = DataLoader(trainset,
                              num_workers=16,
                              shuffle=shuffle,
                              sampler=train_sampler,
                              batch_size=args.batch_size,
                              pin_memory=False,
                              drop_last=True,
                              collate_fn=collate_fn)

    batch_to_gpu = data_functions.get_batch_to_gpu('FastPitch')

    model.train()

    train_tblogger = TBLogger(local_rank, args.output, 'train')
    val_tblogger = TBLogger(local_rank, args.output, 'val', dummies=True)
    if args.ema_decay > 0:
        val_ema_tblogger = TBLogger(local_rank, args.output, 'val_ema')

    val_loss = 0.0
    total_iter = 0
    torch.cuda.synchronize()
    for epoch in range(start_epoch, args.epochs + 1):
        epoch_start_time = time.time()

        epoch_loss = 0.0
        epoch_mel_loss = 0.0
        epoch_num_frames = 0
        epoch_frames_per_sec = 0.0

        if distributed_run:
            train_loader.sampler.set_epoch(epoch)

        accumulated_steps = 0
        iter_loss = 0
        iter_num_frames = 0
        iter_meta = {}

        epoch_iter = 0
        num_iters = len(train_loader) // args.gradient_accumulation_steps
        for batch in train_loader:
            if accumulated_steps == 0:
                if epoch_iter == num_iters:
                    break
                total_iter += 1
                epoch_iter += 1
                iter_start_time = time.time()

                old_lr = optimizer.param_groups[0]['lr']
                adjust_learning_rate(total_iter, optimizer, args.learning_rate,
                                     args.warmup_steps)
                new_lr = optimizer.param_groups[0]['lr']

                if new_lr != old_lr:
                    dllog_lrate_change = f'{old_lr:.2E} -> {new_lr:.2E}'
                    train_tblogger.log_value(total_iter, 'lrate', new_lr)
                else:
                    dllog_lrate_change = None

                model.zero_grad()

            x, y, num_frames = batch_to_gpu(batch)
            y_pred = model(x, use_gt_durations=True)
            loss, meta = criterion(y_pred, y)

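            # Scale the loss (and the per-component metrics) so that summing
            # `gradient_accumulation_steps` backward passes reproduces one
            # full-batch gradient.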
            loss /= args.gradient_accumulation_steps
            meta = {
                k: v / args.gradient_accumulation_steps
                for k, v in meta.items()
            }

            if args.amp_run:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            if distributed_run:
                reduced_loss = reduce_tensor(loss.data, world_size).item()
                reduced_num_frames = reduce_tensor(num_frames.data, 1).item()
                meta = {
                    k: reduce_tensor(v, world_size)
                    for k, v in meta.items()
                }
            else:
                reduced_loss = loss.item()
                reduced_num_frames = num_frames.item()
            if np.isnan(reduced_loss):
                raise Exception("loss is NaN")

            accumulated_steps += 1
            iter_loss += reduced_loss
            iter_num_frames += reduced_num_frames
            iter_meta = {k: iter_meta.get(k, 0) + meta.get(k, 0) for k in meta}

            if accumulated_steps % args.gradient_accumulation_steps == 0:

                train_tblogger.log_grads(total_iter, model)
                if args.amp_run:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.grad_clip_thresh)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.grad_clip_thresh)

                optimizer.step()
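                # Refresh the exponential-moving-average copy of the weights
                # after each optimizer step.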
                apply_ema_decay(model, ema_model, args.ema_decay)

                iter_stop_time = time.time()
                iter_time = iter_stop_time - iter_start_time
                frames_per_sec = iter_num_frames / iter_time
                epoch_frames_per_sec += frames_per_sec
                epoch_loss += iter_loss
                epoch_num_frames += iter_num_frames
                iter_mel_loss = iter_meta['mel_loss'].item()
                epoch_mel_loss += iter_mel_loss

                DLLogger.log(
                    (epoch, epoch_iter, num_iters),
                    OrderedDict([('train_loss', iter_loss),
                                 ('train_mel_loss', iter_mel_loss),
                                 ('train_frames/s', frames_per_sec),
                                 ('took', iter_time),
                                 ('lrate_change', dllog_lrate_change)]))
                train_tblogger.log_meta(total_iter, iter_meta)

                accumulated_steps = 0
                iter_loss = 0
                iter_num_frames = 0
                iter_meta = {}

        # Finished epoch
        epoch_stop_time = time.time()
        epoch_time = epoch_stop_time - epoch_start_time

        DLLogger.log((epoch, ),
                     data=OrderedDict([
                         ('avg_train_loss', epoch_loss / epoch_iter),
                         ('avg_train_mel_loss', epoch_mel_loss / epoch_iter),
                         ('avg_train_frames/s', epoch_num_frames / epoch_time),
                         ('took', epoch_time)
                     ]))

        tik = time.time()
        val_loss, meta, num_frames = validate(model,
                                              criterion,
                                              valset,
                                              args.batch_size,
                                              world_size,
                                              collate_fn,
                                              distributed_run,
                                              local_rank,
                                              batch_to_gpu,
                                              use_gt_durations=True)
        tok = time.time()

        DLLogger.log((epoch, ),
                     data=OrderedDict([
                         ('val_loss', val_loss),
                         ('val_mel_loss', meta['mel_loss'].item()),
                         ('val_frames/s', num_frames / (tok - tik)),
                         ('took', tok - tik),
                     ]))
        val_tblogger.log_meta(total_iter, meta)

        if args.ema_decay > 0:
            tik_e = time.time()
            val_loss_e, meta_e, num_frames_e = validate(ema_model,
                                                        criterion,
                                                        valset,
                                                        args.batch_size,
                                                        world_size,
                                                        collate_fn,
                                                        distributed_run,
                                                        local_rank,
                                                        batch_to_gpu,
                                                        use_gt_durations=True)
            tok_e = time.time()

            DLLogger.log(
                (epoch, ),
                data=OrderedDict([
                    ('val_ema_loss', val_loss_e),
                    ('val_ema_mel_loss', meta_e['mel_loss'].item()),
                    ('val_ema_frames/s', num_frames_e / (tok_e - tik_e)),
                    ('took', tok_e - tik_e),
                ]))
            val_ema_tblogger.log_meta(total_iter, meta_e)

        if (epoch > 0 and args.epochs_per_checkpoint > 0
                and (epoch % args.epochs_per_checkpoint == 0)
                and local_rank == 0):

            checkpoint_path = os.path.join(args.output,
                                           f"FastPitch_checkpoint_{epoch}.pt")
            save_checkpoint(local_rank, model, ema_model, optimizer, epoch,
                            model_config, args.amp_run, checkpoint_path)
        if local_rank == 0:
            DLLogger.flush()

    # Finished training
    DLLogger.log((),
                 data=OrderedDict([
                     ('avg_train_loss', epoch_loss / epoch_iter),
                     ('avg_train_mel_loss', epoch_mel_loss / epoch_iter),
                     ('avg_train_frames/s', epoch_num_frames / epoch_time),
                 ]))
    DLLogger.log((),
                 data=OrderedDict([
                     ('val_loss', val_loss),
                     ('val_mel_loss', meta['mel_loss'].item()),
                     ('val_frames/s', num_frames / (tok - tik)),
                 ]))
    if local_rank == 0:
        DLLogger.flush()
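
The `adjust_learning_rate` helper is defined elsewhere in the script. A sketch
consistent with the call `adjust_learning_rate(total_iter, optimizer,
args.learning_rate, args.warmup_steps)` is an inverse-square-root (Noam-style)
warmup, shown below as an assumption rather than the verified implementation:

def adjust_learning_rate(total_iter, opt, learning_rate, warmup_iters=None):
    # Ramp the LR up for `warmup_iters` steps, then decay it as 1/sqrt(step);
    # the two branches meet at total_iter == warmup_iters.
    if warmup_iters == 0:
        scale = 1.0
    elif total_iter > warmup_iters:
        scale = 1.0 / (total_iter ** 0.5)
    else:
        scale = total_iter / (warmup_iters ** 1.5)
    for param_group in opt.param_groups:
        param_group['lr'] = learning_rate * scale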
Example #3
# Imports assumed for this excerpt; RunConfig, AverageMeter, accuracy,
# count_parameters, cross_entropy_with_label_smoothing and
# get_unpruned_weights are project-local helpers.
import json
import os
import time
from datetime import timedelta

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.nn as nn

import apex
from apex.parallel import DistributedDataParallel as DDP
from torchprofile import profile_macs

class RunManager:
    def __init__(self, path, net, run_config: RunConfig, out_log=True):
        self.path = path
        self.net = net
        self.run_config = run_config
        self.out_log = out_log

        self._logs_path, self._save_path = None, None
        self.best_acc = 0
        self.start_epoch = 0
        gpu = self.run_config.local_rank
        torch.cuda.set_device(gpu)

        # initialize model (default)
        self.net.init_model(run_config.model_init, run_config.init_div_groups)

        # net info
        self.net = self.net.cuda()
        if run_config.local_rank == 0:
            self.print_net_info()

        if self.run_config.sync_bn:
            self.net = apex.parallel.convert_syncbn_model(self.net)
        print('local_rank: %d' % self.run_config.local_rank)

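        # Linear LR scaling rule: scale the base learning rate with the global
        # batch size (per-GPU batch size * world size) relative to a reference
        # batch size of 256.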
        self.run_config.init_lr = self.run_config.init_lr * float(
            self.run_config.train_batch_size *
            self.run_config.world_size) / 256.
        self.criterion = nn.CrossEntropyLoss()
        if self.run_config.no_decay_keys:
            keys = self.run_config.no_decay_keys.split('#')
            self.optimizer = self.run_config.build_optimizer([
                self.net.get_parameters(
                    keys, mode='exclude'),  # parameters with weight decay
                self.net.get_parameters(
                    keys, mode='include'),  # parameters without weight decay
            ])
        else:
            self.optimizer = self.run_config.build_optimizer(
                self.net.weight_parameters())
        # self.net, self.optimizer = amp.initialize(self.net, self.optimizer, opt_level='O1')
        self.net = DDP(self.net, delay_allreduce=True)
        cudnn.benchmark = True

    """ save path and log path """

    @property
    def save_path(self):
        if self._save_path is None:
            save_path = os.path.join(self.path, 'checkpoint')
            os.makedirs(save_path, exist_ok=True)
            self._save_path = save_path
        return self._save_path

    @property
    def logs_path(self):
        if self._logs_path is None:
            logs_path = os.path.join(self.path, 'logs')
            os.makedirs(logs_path, exist_ok=True)
            self._logs_path = logs_path
        return self._logs_path

    """ net info """

    def reset_model(self, model, model_origin=None):
        self.net = model
        self.net.init_model(self.run_config.model_init,
                            self.run_config.init_div_groups)
        if model_origin is not None:
            if self.run_config.local_rank == 0:
                print('-' * 30 + ' start pruning ' + '-' * 30)
            get_unpruned_weights(self.net, model_origin)
            if self.run_config.local_rank == 0:
                print('-' * 30 + ' end pruning ' + '-' * 30)
        # net info
        self.net = self.net.cuda()
        if self.run_config.local_rank == 0:
            self.print_net_info()

        if self.run_config.sync_bn:
            self.net = apex.parallel.convert_syncbn_model(self.net)
        print('local_rank: %d' % self.run_config.local_rank)

        self.criterion = nn.CrossEntropyLoss()
        if self.run_config.no_decay_keys:
            keys = self.run_config.no_decay_keys.split('#')
            self.optimizer = self.run_config.build_optimizer([
                self.net.get_parameters(
                    keys, mode='exclude'),  # parameters with weight decay
                self.net.get_parameters(
                    keys, mode='include'),  # parameters without weight decay
            ])
        else:
            self.optimizer = self.run_config.build_optimizer(
                self.net.weight_parameters())
        # model, self.optimizer = amp.initialize(model, self.optimizer,
        #                                        opt_level='O2',
        #                                        keep_batchnorm_fp32=True,
        #                                        loss_scale=1.0
        #                                        )
        self.net = DDP(self.net, delay_allreduce=True)
        cudnn.benchmark = True
        # if model_origin!=None:
        #     if self.run_config.local_rank==0:
        #         print('-'*30+' start training bn '+'-'*30)
        #     self.train_bn(1)
        #     if self.run_config.local_rank==0:
        #         print('-'*30+' end training bn '+'-'*30)

    # noinspection PyUnresolvedReferences
    def net_flops(self):
        data_shape = [1] + list(self.run_config.data_provider.data_shape)

        net = self.net
        input_var = torch.zeros(data_shape).cuda()
        with torch.no_grad():
            flops = profile_macs(net, input_var)
        return flops

    def print_net_info(self):
        # parameters
        total_params = count_parameters(self.net)
        if self.out_log:
            print('Total training params: %.2fM' % (total_params / 1e6))
        net_info = {
            'param': '%.2fM' % (total_params / 1e6),
        }

        # flops
        flops = self.net_flops()
        if self.out_log:
            print('Total FLOPs: %.1fM' % (flops / 1e6))
        net_info['flops'] = '%.1fM' % (flops / 1e6)

        # config
        if self.out_log:
            print('Net config: ' + str(self.net.config))
        net_info['config'] = str(self.net.config)

        with open('%s/net_info.txt' % self.logs_path, 'w') as fout:
            fout.write(json.dumps(net_info, indent=4) + '\n')

    """ save and load models """

    def save_model(self, checkpoint=None, is_best=False, model_name=None):
        if checkpoint is None:
            checkpoint = {'state_dict': self.net.module.state_dict()}

        if model_name is None:
            model_name = 'checkpoint.pth.tar'

        # add `dataset` info to the checkpoint
        checkpoint['dataset'] = self.run_config.dataset
        latest_fname = os.path.join(self.save_path, 'latest.txt')
        model_path = os.path.join(self.save_path, model_name)
        with open(latest_fname, 'w') as fout:
            fout.write(model_path + '\n')
        torch.save(checkpoint, model_path)

        if is_best:
            best_path = os.path.join(self.save_path, 'model_best.pth.tar')
            torch.save({'state_dict': checkpoint['state_dict']}, best_path)

    def load_model(self, model_fname=None):
        latest_fname = os.path.join(self.save_path, 'latest.txt')
        if model_fname is None and os.path.exists(latest_fname):
            with open(latest_fname, 'r') as fin:
                model_fname = fin.readline().rstrip('\n')
        # noinspection PyBroadException
        try:
            if model_fname is None or not os.path.exists(model_fname):
                model_fname = '%s/checkpoint.pth.tar' % self.save_path
                with open(latest_fname, 'w') as fout:
                    fout.write(model_fname + '\n')
            if self.out_log:
                print("=> loading checkpoint '{}'".format(model_fname))

            if torch.cuda.is_available():
                checkpoint = torch.load(model_fname)
            else:
                checkpoint = torch.load(model_fname, map_location='cpu')

            self.net.module.load_state_dict(checkpoint['state_dict'])
            # set new manual seed
            new_manual_seed = int(time.time())
            torch.manual_seed(new_manual_seed)
            torch.cuda.manual_seed_all(new_manual_seed)
            np.random.seed(new_manual_seed)

            if 'epoch' in checkpoint:
                self.start_epoch = checkpoint['epoch'] + 1
            if 'best_acc' in checkpoint:
                self.best_acc = checkpoint['best_acc']
            if 'optimizer' in checkpoint:
                self.optimizer.load_state_dict(checkpoint['optimizer'])

            if self.out_log:
                print("=> loaded checkpoint '{}'".format(model_fname))
        except Exception:
            if self.out_log:
                print('failed to load checkpoint from %s' % self.save_path)

    def save_config(self, print_info=True):
        """ dump run_config and net_config to the model_folder """
        os.makedirs(self.path, exist_ok=True)
        net_save_path = os.path.join(self.path, 'net.config')
        json.dump(self.net.module.config, open(net_save_path, 'w'), indent=4)
        if print_info:
            print('Network configs dump to %s' % net_save_path)

        run_save_path = os.path.join(self.path, 'run.config')
        json.dump(self.run_config.config, open(run_save_path, 'w'), indent=4)
        if print_info:
            print('Run configs dump to %s' % run_save_path)

    """ train and test """

    def write_log(self, log_str, prefix, should_print=True):
        """ prefix: valid, train, test """
        if prefix in ['valid', 'test']:
            with open(os.path.join(self.logs_path, 'valid_console.txt'),
                      'a') as fout:
                fout.write(log_str + '\n')
                fout.flush()
        if prefix in ['valid', 'test', 'train']:
            with open(os.path.join(self.logs_path, 'train_console.txt'),
                      'a') as fout:
                if prefix in ['valid', 'test']:
                    fout.write('=' * 10)
                fout.write(log_str + '\n')
                fout.flush()
        if prefix == 'prune':
            with open(os.path.join(self.logs_path, 'prune_console.txt'),
                      'a') as fout:
                fout.write(log_str + '\n')
                fout.flush()
        if should_print:
            print(log_str)

    def validate(self,
                 is_test=True,
                 net=None,
                 use_train_mode=False,
                 return_top5=False):
        if is_test:
            data_loader = self.run_config.test_loader
        else:
            data_loader = self.run_config.valid_loader

        if net is None:
            net = self.net

        if use_train_mode:
            net.train()
        else:
            net.eval()
        batch_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()

        end = time.time()
        # noinspection PyUnresolvedReferences
        with torch.no_grad():
            for i, data in enumerate(data_loader):
                images, labels = data[0].cuda(non_blocking=True), data[1].cuda(
                    non_blocking=True)
                # images, labels = data[0].cuda(), data[1].cuda()
                # compute output
                output = net(images)
                loss = self.criterion(output, labels)
                # measure accuracy and record loss
                acc1, acc5 = accuracy(output, labels, topk=(1, 5))
                reduced_loss = self.reduce_tensor(loss.data)
                acc1 = self.reduce_tensor(acc1)
                acc5 = self.reduce_tensor(acc5)
                losses.update(reduced_loss, images.size(0))
                top1.update(acc1[0], images.size(0))
                top5.update(acc5[0], images.size(0))
                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                if i % self.run_config.print_frequency == 0 or i + 1 == len(
                        data_loader):
                    if is_test:
                        prefix = 'Test'
                    else:
                        prefix = 'Valid'
                    test_log = prefix + ': [{0}/{1}]\t' \
                                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t' \
                                        'Top-1 acc {top1.val:.3f} ({top1.avg:.3f})'. \
                        format(i, len(data_loader) - 1, batch_time=batch_time, loss=losses, top1=top1)
                    if return_top5:
                        test_log += '\tTop-5 acc {top5.val:.3f} ({top5.avg:.3f})'.format(
                            top5=top5)
                    print(test_log)
        self.run_config.valid_loader.reset()
        self.run_config.test_loader.reset()
        if return_top5:
            return losses.avg, top1.avg, top5.avg
        else:
            return losses.avg, top1.avg

    def train_bn(self, epochs=1):
        if self.run_config.local_rank == 0:
            print('training bn')
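        # Reset the BatchNorm running statistics, then re-estimate them with
        # plain forward passes in train mode; no parameter updates happen here.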
        for m in self.net.modules():
            if isinstance(m, torch.nn.BatchNorm2d):
                m.running_mean = torch.zeros_like(m.running_mean)
                m.running_var = torch.ones_like(m.running_var)
        self.net.train()
        for i in range(epochs):
            for _, data in enumerate(self.run_config.train_loader):
                images, labels = data[0].cuda(non_blocking=True), data[1].cuda(
                    non_blocking=True)
                output = self.net(images)
                del output, images, labels
        if self.run_config.local_rank == 0:
            print('training bn finished')

    def train_one_epoch(self, adjust_lr_func, train_log_func, epoch):
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()

        # switch to train mode
        self.net.train()

        end = time.time()
        for i, data in enumerate(self.run_config.train_loader):
            data_time.update(time.time() - end)
            new_lr = adjust_lr_func(i)
            images, labels = data[0].cuda(non_blocking=True), data[1].cuda(
                non_blocking=True)
            # compute output
            output = self.net(images)
            if self.run_config.label_smoothing > 0:
                loss = cross_entropy_with_label_smoothing(
                    output, labels, self.run_config.label_smoothing)
            else:
                loss = self.criterion(output, labels)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, labels, topk=(1, 5))
            reduced_loss = self.reduce_tensor(loss.data)
            acc1 = self.reduce_tensor(acc1)
            acc5 = self.reduce_tensor(acc5)
            losses.update(reduced_loss, images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # compute gradient and do SGD step
            self.net.zero_grad()  # or self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            torch.cuda.synchronize()
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if (i % self.run_config.print_frequency == 0
                    or i + 1 == len(self.run_config.train_loader)
                ) and self.run_config.local_rank == 0:
                batch_log = train_log_func(i, batch_time, data_time, losses,
                                           top1, top5, new_lr)
                self.write_log(batch_log, 'train')
        return top1, top5

    def train(self, print_top5=False):
        def train_log_func(epoch_, i, batch_time, data_time, losses, top1,
                           top5, lr):
            batch_log = 'Train [{0}][{1}/{2}]\t' \
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                        'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' \
                        'Loss {losses.val:.4f} ({losses.avg:.4f})\t' \
                        'Top-1 acc {top1.val:.3f} ({top1.avg:.3f})'. \
                format(epoch_ + 1, i, len(self.run_config.train_loader) - 1,
                       batch_time=batch_time, data_time=data_time, losses=losses, top1=top1)
            if print_top5:
                batch_log += '\tTop-5 acc {top5.val:.3f} ({top5.avg:.3f})'.format(
                    top5=top5)
            batch_log += '\tlr {lr:.5f}'.format(lr=lr)
            return batch_log

        for epoch in range(self.start_epoch, self.run_config.n_epochs):
            if self.run_config.local_rank == 0:
                print('\n', '-' * 30, 'Train epoch: %d' % (epoch + 1),
                      '-' * 30, '\n')

            end = time.time()
            train_top1, train_top5 = self.train_one_epoch(
                lambda i: self.run_config.adjust_learning_rate(
                    self.optimizer, epoch, i, len(self.run_config.train_loader)
                ), lambda i, batch_time, data_time, losses, top1, top5, new_lr:
                train_log_func(epoch, i, batch_time, data_time, losses, top1,
                               top5, new_lr), epoch)
            time_per_epoch = time.time() - end
            seconds_left = int(
                (self.run_config.n_epochs - epoch - 1) * time_per_epoch)
            if self.run_config.local_rank == 0:
                print('Time per epoch: %s, Est. complete in: %s' %
                      (str(timedelta(seconds=time_per_epoch)),
                       str(timedelta(seconds=seconds_left))))

            if (epoch + 1) % self.run_config.validation_frequency == 0:
                val_loss, val_acc, val_acc5 = self.validate(is_test=False,
                                                            return_top5=True)
                is_best = val_acc > self.best_acc
                self.best_acc = max(self.best_acc, val_acc)
                val_log = 'Valid [{0}/{1}]\tloss {2:.3f}\ttop-1 acc {3:.3f} ({4:.3f})'. \
                    format(epoch + 1, self.run_config.n_epochs, val_loss, val_acc, self.best_acc)
                if print_top5:
                    val_log += '\ttop-5 acc {0:.3f}\tTrain top-1 {top1.avg:.3f}\ttop-5 {top5.avg:.3f}'. \
                        format(val_acc5, top1=train_top1, top5=train_top5)
                else:
                    val_log += '\tTrain top-1 {top1.avg:.3f}'.format(
                        top1=train_top1)
                if self.run_config.local_rank == 0:
                    self.write_log(val_log, 'valid')
            else:
                is_best = False
            if self.run_config.local_rank == 0:
                self.save_model(
                    {
                        'epoch': epoch,
                        'best_acc': self.best_acc,
                        'optimizer': self.optimizer.state_dict(),
                        'state_dict': self.net.state_dict(),
                    },
                    is_best=is_best)
            self.run_config.train_loader.reset()
            self.run_config.valid_loader.reset()
            self.run_config.test_loader.reset()

    def reduce_tensor(self, tensor):
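        # Average a metric tensor across all workers: all-reduce with SUM,
        # then divide by the world size.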
        rt = tensor.clone()
        dist.all_reduce(rt, op=dist.ReduceOp.SUM)
        rt /= self.run_config.world_size
        return rt
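
A minimal launch sketch for the class above (hypothetical names: the
`RunConfig` construction and `build_network` come from the surrounding
project; one process per GPU is assumed, e.g. via torch.distributed.launch):

run_config = RunConfig(...)   # dataset, LR, batch sizes, local_rank, ...
net = build_network()         # an nn.Module exposing the expected helpers
manager = RunManager('./output', net, run_config)
manager.save_config()
manager.train(print_top5=True)
test_loss, test_top1, test_top5 = manager.validate(is_test=True,
                                                   return_top5=True)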
Example #4
# Imports assumed for this excerpt; logger, utils.tbLogger, LoadDatasets,
# LoadLosses, BertAdam and Adam are provided by the surrounding project
# (e.g. pytorch_pretrained_bert for the optimizers).
import argparse
import json
import os

import torch
import torch.distributed as dist
import yaml
from easydict import EasyDict as edict

def main():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--bert_model",
        default="bert-base-uncased",
        type=str,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.",
    )
    parser.add_argument(
        "--from_pretrained",
        default="bert-base-uncased",
        type=str,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.",
    )
    parser.add_argument(
        "--output_dir",
        default="save",
        type=str,
        help=
        "The output directory where the model checkpoints will be written.",
    )
    parser.add_argument(
        "--config_file",
        default="config/bert_config.json",
        type=str,
        help="The config file which specified the model details.",
    )
    parser.add_argument("--learning_rate",
                        default=2e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument(
        "--num_train_epochs",
        default=20,
        type=int,
        help="Total number of training epochs to perform.",
    )
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.",
    )
    parser.add_argument("--no_cuda",
                        action="store_true",
                        help="Whether not to use CUDA when available")
    parser.add_argument(
        "--do_lower_case",
        default=True,
        type=bool,
        help=
        "Whether to lower case the input text. True for uncased models, False for cased models.",
    )
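    # NOTE: argparse's `type=bool` treats any non-empty string (including
    # "False") as True; this affects --do_lower_case above and --in_memory
    # below.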
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument("--seed",
                        type=int,
                        default=0,
                        help="random seed for initialization")
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help=
        "Number of updates steps to accumualte before performing a backward/update pass.",
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit float precision instead of 32-bit",
    )
    parser.add_argument(
        "--loss_scale",
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n",
    )
    parser.add_argument("--num_workers",
                        type=int,
                        default=16,
                        help="Number of workers in the dataloader.")
    parser.add_argument(
        "--save_name",
        default='',
        type=str,
        help="save name for training.",
    )
    parser.add_argument("--use_chunk",
                        default=0,
                        type=float,
                        help="whether use chunck for parallel training.")
    parser.add_argument("--in_memory",
                        default=False,
                        type=bool,
                        help="whether use chunck for parallel training.")
    parser.add_argument("--optimizer",
                        default='BertAdam',
                        type=str,
                        help="whether use chunck for parallel training.")
    parser.add_argument("--tasks",
                        default='',
                        type=str,
                        help="1-2-3... training task separate by -")
    parser.add_argument(
        "--freeze",
        default=-1,
        type=int,
        help="till which layer of textual stream of vilbert need to fixed.")
    parser.add_argument("--vision_scratch",
                        action="store_true",
                        help="whether pre-trained the image or not.")
    parser.add_argument("--evaluation_interval",
                        default=1,
                        type=int,
                        help="evaluate very n epoch.")
    parser.add_argument("--lr_scheduler",
                        default='mannul',
                        type=str,
                        help="whether use learning rate scheduler.")
    parser.add_argument("--baseline",
                        action="store_true",
                        help="whether use single stream baseline.")
    parser.add_argument("--compact",
                        action="store_true",
                        help="whether use compact vilbert model.")
    parser.add_argument("--debug",
                        action="store_true",
                        help="whether in debug mode.")
    parser.add_argument(
        "--tensorboard_dir",
        default="tensorboard_log",
        type=str,
        help="The output directory where tensorboard log will be written.",
    )
    parser.add_argument(
        "--batch_size",
        default=-1,
        type=int,
        help="Custom Batch size for task.",
    )
    parser.add_argument(
        "--data_root",
        default="",
        type=str,
        help="The data root of the task.",
    )
    args = parser.parse_args()
    with open('vlbert_tasks.yml', 'r') as f:
        task_cfg = edict(yaml.safe_load(f))

    # random.seed(args.seed)
    # np.random.seed(args.seed)
    # torch.manual_seed(args.seed)

    if args.baseline:
        from pytorch_pretrained_bert.modeling import BertConfig
        from vilbert.basebert import BaseBertForVLTasks
    elif args.compact:
        from vilbert.vilbert_compact import BertConfig
        from vilbert.vilbert_compact import VILBertForVLTasks
    else:
        from vilbert.vilbert import BertConfig
        from vilbert.vilbert import VILBertForVLTasks

    task_names = []
    task_lr = []
    for i, task_id in enumerate(args.tasks.split('-')):
        task = 'TASK' + task_id
        name = task_cfg[task]['name']
        task_names.append(name)
        task_lr.append(task_cfg[task]['lr'])

    base_lr = min(task_lr)
    loss_scale = {}
    for i, task_id in enumerate(args.tasks.split('-')):
        task = 'TASK' + task_id
        loss_scale[task] = task_lr[i] / base_lr

    if args.save_name:
        prefix = '-' + args.save_name
    else:
        prefix = ''
    timeStamp = '-'.join(task_names) + '_' + args.config_file.split(
        '/')[1].split('.')[0] + prefix
    savePath = os.path.join(args.output_dir, timeStamp)
    logPath = os.path.join(args.tensorboard_dir, timeStamp)

    bert_weight_name = json.load(
        open("config/" + args.bert_model + "_weight_name.json", "r"))

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend="nccl")

    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    default_gpu = False
    if dist.is_available() and args.local_rank != -1:
        rank = dist.get_rank()
        if rank == 0:
            default_gpu = True
    else:
        default_gpu = True

    if default_gpu:
        if not os.path.exists(savePath):
            os.makedirs(savePath)

    config = BertConfig.from_json_file(args.config_file)
    if default_gpu:
        # save all the hidden parameters.
        with open(os.path.join(savePath, 'command.txt'), 'w') as f:
            print(args, file=f)  # Python 3.x
            print('\n', file=f)
            print(config, file=f)

    if args.batch_size != -1:
        for i, task_id in enumerate(args.tasks.split('-')):
            task = 'TASK' + task_id
            task_cfg[task]['batch_size'] = args.batch_size

    if args.data_root != "":
        for i, task_id in enumerate(args.tasks.split('-')):
            data_root = args.data_root
            task = 'TASK' + task_id
            task_cfg[task]['dataroot'] = data_root
            task_cfg[task]['features_h5path1'] = os.path.join(
                data_root, task_cfg[task]['features_h5path1'].split('/')[-1])
            task_cfg[task]['features_h5path2'] = os.path.join(
                data_root, task_cfg[task]['features_h5path2'].split('/')[-1])
            task_cfg[task]['train_annotations_jsonpath'] = os.path.join(
                data_root,
                task_cfg[task]['train_annotations_jsonpath'].split('/')[-1])
            task_cfg[task]['val_annotations_jsonpath'] = os.path.join(
                data_root,
                task_cfg[task]['val_annotations_jsonpath'].split('/')[-1])

    # Debug subsets (train_100.jsonl / val_100.jsonl) exist for the VCR dataset only; other datasets need their own.
    if args.debug:
        for i, task_id in enumerate(args.tasks.split('-')):
            task = 'TASK' + task_id
            task_cfg[task]['train_annotations_jsonpath'] = '/'.join(
                task_cfg[task]['train_annotations_jsonpath'].split('/')[:-1] +
                ['train_100.jsonl'])
            task_cfg[task]['val_annotations_jsonpath'] = '/'.join(
                task_cfg[task]['val_annotations_jsonpath'].split('/')[:-1] +
                ['val_100.jsonl'])
            task_cfg[task]['batch_size'] = 2

    # args.debug is currently handled only by the VCR dataset (vcr_dataset.py); other datasets still need support for it.
    task_batch_size, task_num_iters, task_ids, task_datasets_train, task_datasets_val, \
            task_dataloader_train, task_dataloader_val = LoadDatasets(args, task_cfg, args.tasks.split('-'), args.debug)

    tbLogger = utils.tbLogger(logPath, savePath, task_names, task_ids,
                              task_num_iters, args.gradient_accumulation_steps)

    # if n_gpu > 0:
    # torch.cuda.manual_seed_all(args.seed)

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    num_train_optimization_steps = max(task_num_iters.values(
    )) * args.num_train_epochs // args.gradient_accumulation_steps
    num_labels = max(
        [dataset.num_labels for dataset in task_datasets_train.values()])

    task_start_iter = {}
    task_interval = {}
    for task_id, num_iter in task_num_iters.items():
        task_start_iter[task_id] = num_train_optimization_steps - (
            task_cfg[task_id]['num_epoch'] * num_iter //
            args.gradient_accumulation_steps)
        task_interval[task_id] = num_train_optimization_steps // (
            task_cfg[task_id]['num_epoch'] * num_iter //
            args.gradient_accumulation_steps)

    if args.baseline:
        model = BaseBertForVLTasks.from_pretrained(args.from_pretrained,
                                                   config,
                                                   num_labels=num_labels,
                                                   default_gpu=default_gpu)
    else:
        model = VILBertForVLTasks.from_pretrained(args.from_pretrained,
                                                  config,
                                                  num_labels=num_labels,
                                                  default_gpu=default_gpu)

    task_losses = LoadLosses(args, task_cfg, args.tasks.split('-'))
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )
        model = DDP(model, delay_allreduce=True)

    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]

    if args.freeze != -1:
        bert_weight_name_filtered = []
        for name in bert_weight_name:
            if 'embeddings' in name:
                bert_weight_name_filtered.append(name)
            elif 'encoder' in name:
                layer_num = name.split('.')[2]
                if int(layer_num) <= args.freeze:
                    bert_weight_name_filtered.append(name)

        optimizer_grouped_parameters = []
        for key, value in dict(model.named_parameters()).items():
            if key[12:] in bert_weight_name_filtered:
                value.requires_grad = False

        if default_gpu:
            print("filtered weight")
            print(bert_weight_name_filtered)

    optimizer_grouped_parameters = []
    lr = args.learning_rate
    for key, value in dict(model.named_parameters()).items():
        if value.requires_grad:
            if 'vil_prediction' in key:
                # if args.learning_rate <= 2e-5:
                lr = 1e-4
            else:
                if args.vision_scratch:
                    if key[12:] in bert_weight_name:
                        lr = args.learning_rate
                    else:
                        lr = 1e-4
                else:
                    lr = args.learning_rate
            if any(nd in key for nd in no_decay):
                optimizer_grouped_parameters += [{
                    "params": [value],
                    "lr": lr,
                    "weight_decay": 0.0
                }]
            else:
                optimizer_grouped_parameters += [{
                    "params": [value],
                    "lr": lr,
                    "weight_decay": 0.01
                }]

    if default_gpu:
        print(len(list(model.named_parameters())),
              len(optimizer_grouped_parameters))

    max_num_iter = max(task_num_iters.values())
    max_batch_size = max(task_batch_size.values())

    if args.optimizer == 'BertAdam':
        optimizer = BertAdam(
            optimizer_grouped_parameters,
            lr=args.learning_rate,
            warmup=args.warmup_proportion,
            t_total=num_train_optimization_steps,
            schedule='warmup_constant',
        )
    elif args.optimizer == 'Adam':
        optimizer = Adam(
            optimizer_grouped_parameters,
            lr=base_lr,
            warmup=args.warmup_proportion,
            t_total=num_train_optimization_steps,
            schedule='warmup_constant',
        )
    elif args.optimizer == 'Adamax':
        optimizer = Adamax(
            optimizer_grouped_parameters,
            lr=base_lr,
            warmup=args.warmup_proportion,
            t_total=num_train_optimization_steps,
            schedule='warmup_constant',
        )

    if args.lr_scheduler == 'automatic':
        lr_scheduler = ReduceLROnPlateau(optimizer,
                                         mode='max',
                                         factor=0.2,
                                         patience=1,
                                         cooldown=1,
                                         threshold=0.001)
    elif args.lr_scheduler == 'mannul':
        lr_reduce_list = np.array([12, 16])

        # lr_reduce_list = np.array([6, 8, 10])
        def lr_lambda_fun(epoch):
            return pow(0.1, np.sum(lr_reduce_list <= epoch))

        lr_scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda_fun)

    if default_gpu:
        print("***** Running training *****")
        print("  Num Iters: ", task_num_iters)
        print("  Batch size: ", task_batch_size)
        print("  Num steps: %d" % num_train_optimization_steps)

    startIterID = 0
    # initialize the data iteration.
    task_iter_train = {name: None for name in task_ids}
    task_count = {name: 0 for name in task_ids}
    for epochId in tqdm(range(args.num_train_epochs), desc="Epoch"):
        model.train()
        for step in range(max_num_iter):
            iterId = startIterID + step + (epochId * max_num_iter)
            for task_id in task_ids:
                if iterId >= task_start_iter[task_id]:
                    # if iterId % task_interval[task_id] == 0:
                    loss, score = ForwardModelsTrain(args, task_cfg, device,
                                                     task_id, task_count,
                                                     task_iter_train,
                                                     task_dataloader_train,
                                                     model, task_losses,
                                                     task_start_iter)
                    loss = loss * loss_scale[task_id]
                    if args.gradient_accumulation_steps > 1:
                        loss = loss / args.gradient_accumulation_steps

                    loss.backward()
                    if (step + 1) % args.gradient_accumulation_steps == 0:
                        optimizer.step()
                        model.zero_grad()

                        if default_gpu:
                            tbLogger.step_train(epochId, iterId, float(loss),
                                                float(score),
                                                optimizer.show_lr(), task_id,
                                                'train')

            if step % (20 * args.gradient_accumulation_steps
                       ) == 0 and step != 0 and default_gpu:
                tbLogger.showLossTrain()

        model.eval()
        # when run evaluate, we run each task sequentially.
        for task_id in task_ids:
            for i, batch in enumerate(task_dataloader_val[task_id]):
                loss, score, batch_size = ForwardModelsVal(
                    args, task_cfg, device, task_id, batch, model, task_losses)
                tbLogger.step_val(epochId, float(loss), float(score), task_id,
                                  batch_size, 'val')
                if default_gpu:
                    sys.stdout.write('%d/%d\r' %
                                     (i, len(task_dataloader_val[task_id])))
                    sys.stdout.flush()

        ave_score = tbLogger.showLossVal()
        if args.lr_scheduler == 'automatic':
            lr_scheduler.step(ave_score)
            logger.info("best average score is %3f" % lr_scheduler.best)
        else:
            lr_scheduler.step()

        if default_gpu:
            # Save a trained model
            logger.info("** ** * Saving fine - tuned model on " + logPath +
                        "** ** * ")
            model_to_save = (
                model.module if hasattr(model, "module") else model
            )  # Only save the model it-self

            if not os.path.exists(savePath):
                os.makedirs(savePath)
            output_model_file = os.path.join(
                savePath, "pytorch_model_" + str(epochId) + ".bin")
            torch.save(model_to_save.state_dict(), output_model_file)

    tbLogger.txt_close()
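
The parameter grouping above builds one dict per tensor; the same no-decay split is usually written with two comprehension-built groups. A minimal, self-contained sketch of that pattern (the model and lr are stand-ins; the decay split mirrors the corrected grouping above):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 16), nn.LayerNorm(16))
no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]

optimizer_grouped_parameters = [
    {
        # regular weights: apply weight decay
        "params": [p for n, p in model.named_parameters()
                   if p.requires_grad and not any(nd in n for nd in no_decay)],
        "weight_decay": 0.01,
    },
    {
        # biases and LayerNorm parameters: no weight decay
        "params": [p for n, p in model.named_parameters()
                   if p.requires_grad and any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=2e-5)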
Example #5
def main():
    parser = argparse.ArgumentParser(description='PyTorch FixMatch Training')
    parser.add_argument('--gpu-id', default='0', type=int,
                        help='id(s) for CUDA_VISIBLE_DEVICES')
    parser.add_argument('--num-workers', type=int, default=4,
                        help='number of workers')
    parser.add_argument('--dataset', default='cifar10', type=str,
                        choices=['cifar10', 'cifar100'],
                        help='dataset name')
    parser.add_argument('--num-labeled', type=int, default=4000,
                        help='number of labeled data')
    parser.add_argument("--expand-labels", action="store_true",
                        help="expand labels to fit eval steps")
    parser.add_argument('--arch', default='wideresnet', type=str,
                        choices=['wideresnet', 'resnext'],
                        help='model architecture')
    parser.add_argument('--total-steps', default=2**20, type=int,
                        help='number of total steps to run')
    parser.add_argument('--eval-step', default=1024, type=int,
                        help='number of eval steps to run')
    parser.add_argument('--start-epoch', default=0, type=int,
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('--batch-size', default=64, type=int,
                        help='train batchsize')
    parser.add_argument('--lr', '--learning-rate', default=0.03, type=float,
                        help='initial learning rate')
    parser.add_argument('--warmup', default=0, type=float,
                        help='warmup epochs (unlabeled data based)')
    parser.add_argument('--wdecay', default=5e-4, type=float,
                        help='weight decay')
    parser.add_argument('--nesterov', action='store_true', default=True,
                        help='use nesterov momentum')
    parser.add_argument('--use-ema', action='store_true', default=True,
                        help='use EMA model')
    parser.add_argument('--ema-decay', default=0.999, type=float,
                        help='EMA decay rate')
    parser.add_argument('--mu', default=7, type=int,
                        help='coefficient of unlabeled batch size')
    parser.add_argument('--lambda-u', default=1, type=float,
                        help='coefficient of unlabeled loss')
    parser.add_argument('--T', default=1, type=float,
                        help='pseudo label temperature')
    parser.add_argument('--threshold', default=0.95, type=float,
                        help='pseudo label threshold')
    parser.add_argument('--out', default='result',
                        help='directory to output the result')
    parser.add_argument('--resume', default='', type=str,
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('--seed', default=None, type=int,
                        help="random seed")
    parser.add_argument("--amp", action="store_true",
                        help="use 16-bit (mixed) precision through NVIDIA apex AMP")
    parser.add_argument("--opt_level", type=str, default="O0",
                        help="apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                        "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="For distributed training: local_rank")
    parser.add_argument('--no-progress', action='store_true',
                        help="don't use progress bar")

    args = parser.parse_args()
    global best_acc

    def create_model(args):
        if args.arch == 'wideresnet':
            import models.wideresnet as models
            model = models.build_wideresnet(depth=args.model_depth,
                                            widen_factor=args.model_width,
                                            dropout=0,
                                            num_classes=args.num_classes)
        elif args.arch == 'resnext':
            import models.resnext as models
            model = models.build_resnext(cardinality=args.model_cardinality,
                                         depth=args.model_depth,
                                         width=args.model_width,
                                         num_classes=args.num_classes)
        logger.info("Total params: {:.2f}M".format(
            sum(p.numel() for p in model.parameters())/1e6))
        return model

    if args.local_rank == -1:
        device = torch.device('cuda', args.gpu_id)
        args.world_size = 1
        args.n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device('cuda', args.local_rank)
        torch.distributed.init_process_group(backend='nccl')
        args.world_size = torch.distributed.get_world_size()
        args.n_gpu = 1

    args.device = device

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)

    logger.warning(
        f"Process rank: {args.local_rank}, "
        f"device: {args.device}, "
        f"n_gpu: {args.n_gpu}, "
        f"distributed training: {bool(args.local_rank != -1)}, "
        f"16-bits training: {args.amp}",)

    logger.info(dict(args._get_kwargs()))

    if args.seed is not None:
        set_seed(args)

    if args.local_rank in [-1, 0]:
        os.makedirs(args.out, exist_ok=True)
        writer = SummaryWriter(args.out)
    else:
        writer = None

    if args.dataset == 'cifar10':
        args.num_classes = 10
        if args.arch == 'wideresnet':
            args.model_depth = 28
            args.model_width = 2
        elif args.arch == 'resnext':
            args.model_cardinality = 4
            args.model_depth = 28
            args.model_width = 4

    elif args.dataset == 'cifar100':
        args.num_classes = 100
        if args.arch == 'wideresnet':
            args.model_depth = 28
            args.model_width = 8
        elif args.arch == 'resnext':
            args.model_cardinality = 8
            args.model_depth = 29
            args.model_width = 64

    labeled_dataset, unlabeled_dataset, test_dataset = DATASET_GETTERS[args.dataset](
        args, os.path.join(root_dir, 'data'))

    train_sampler = RandomSampler if args.local_rank == -1 else DistributedSampler

    labeled_trainloader = DataLoader(
        labeled_dataset,
        sampler=train_sampler(labeled_dataset),
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        drop_last=True)

    unlabeled_trainloader = DataLoader(
        unlabeled_dataset,
        sampler=train_sampler(unlabeled_dataset),
        batch_size=args.batch_size*args.mu,
        num_workers=args.num_workers,
        drop_last=True)

    test_loader = DataLoader(
        test_dataset,
        sampler=SequentialSampler(test_dataset),
        batch_size=args.batch_size,
        num_workers=args.num_workers)

    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()

    model = create_model(args)

    if args.local_rank == 0:
        torch.distributed.barrier()

    model.to(args.device)

    no_decay = ['bias', 'bn']
    grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(
            nd in n for nd in no_decay)], 'weight_decay': args.wdecay},
        {'params': [p for n, p in model.named_parameters() if any(
            nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer = optim.SGD(grouped_parameters, lr=args.lr,
                          momentum=0.9, nesterov=args.nesterov)

    args.epochs = math.ceil(args.total_steps / args.eval_step)
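    # an "epoch" in FixMatch is one eval cycle of eval_step optimization
    # steps, so total_steps is split into ceil(total_steps / eval_step) epochs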
    scheduler = get_cosine_schedule_with_warmup(
        optimizer, args.warmup, args.total_steps)

    if args.use_ema:
        from models.ema import ModelEMA
        ema_model = ModelEMA(args, model, args.ema_decay)
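        # ModelEMA (from the FixMatch repo) is assumed to keep a shadow copy
        # of the weights, updated as ema = decay * ema + (1 - decay) * param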

    args.start_epoch = 0

    if args.resume:
        logger.info("==> Resuming from checkpoint..")
        assert os.path.isfile(
            args.resume), "Error: no checkpoint directory found!"
        args.out = os.path.dirname(args.resume)
        checkpoint = torch.load(args.resume)
        best_acc = checkpoint['best_acc']
        args.start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        if args.use_ema:
            ema_model.ema.load_state_dict(checkpoint['ema_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])

    # if args.amp:
    #     from apex import amp
    #     model, optimizer = amp.initialize(
    #         model, optimizer, opt_level=args.opt_level)

    if args.local_rank != -1:
        model = DistributedDataParallel(model)
        # model = torch.nn.parallel.DistributedDataParallel(
        #     model, device_ids=[args.local_rank],
        #     output_device=args.local_rank, find_unused_parameters=True)

    logger.info("***** Running training *****")
    logger.info(f"  Task = {args.dataset}@{args.num_labeled}")
    logger.info(f"  Num Epochs = {args.epochs}")
    logger.info(f"  Batch size per GPU = {args.batch_size}")
    logger.info(
        f"  Total train batch size = {args.batch_size*args.world_size}")
    logger.info(f"  Total optimization steps = {args.total_steps}")

    model.zero_grad()
    train(args, labeled_trainloader, unlabeled_trainloader, test_loader,
          model, optimizer, ema_model, scheduler, writer)
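
get_cosine_schedule_with_warmup comes from the FixMatch codebase rather than from torch; a minimal sketch of the schedule it is assumed to implement, linear warmup followed by a cosine decay wrapped in a LambdaLR:

import math
from torch.optim.lr_scheduler import LambdaLR

def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps,
                                    num_training_steps, num_cycles=7. / 16.):
    def _lr_lambda(current_step):
        if current_step < num_warmup_steps:
            # linear warmup from 0 up to the base lr
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / \
            float(max(1, num_training_steps - num_warmup_steps))
        # cosine decay over the remaining steps
        return max(0.0, math.cos(math.pi * num_cycles * progress))

    return LambdaLR(optimizer, _lr_lambda)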
Example #6
    def forward(self, input):
        return (input * self.a) * self.b


model = DDP(Model(), message_size=1)
# model = DDP(Model(), delay_allreduce=True)

x = torch.cuda.FloatTensor(4096 * 4096)

passed = True
for i in range(10):
    x.fill_(
        i +
        args.local_rank)  # fill x with new values every iteration for sanity
    model.zero_grad()
    out = model(x)
    loss = out.sum()
    # torch.cuda.nvtx.range_push("backward")
    loss.backward()
    # torch.cuda.nvtx.range_pop()

    # torch.cuda.nvtx.range_push("synchronize() + info")
    # torch.cuda.synchronize()
    print("i = {}".format(i))

    def info(name, param, val):
        expected = val * 4096 * 4096 * (2. * i + 1) / 2.
        actual = param.grad.data.sum().item()
        print(name + ": grad.data_ptr() = {}, expected sum {}, got {}".format(
            param.grad.data_ptr(), expected, actual))
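
The expected-sum formula assumes a two-process launch: rank r fills x with i + r, so on each rank the gradient of loss = sum((x * a) * b) with respect to a sums to b * N * (i + r), with N = 4096 * 4096. After apex DDP averages gradients across the two ranks, the expected value becomes b * N * (2 * i + 1) / 2, which is exactly what info checks (val = b for a's gradient, val = a for b's).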
Example #7
def train(args, train_dataset, model):
    """ Train the model """
    args.train_batch_size = args.per_gpu_train_batch_size // args.gradient_accumulation_steps
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    t_total = len(train_dataloader
                  ) // args.gradient_accumulation_steps * args.num_train_epochs

    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer,
                                     warmup_steps=t_total *
                                     args.warmup_proportion,
                                     t_total=t_total)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        if args.fp16_opt_level == "O2":
            keep_batchnorm_fp32 = False
        else:
            keep_batchnorm_fp32 = True
        model, optimizer = amp.initialize(
            model,
            optimizer,
            opt_level=args.fp16_opt_level,
            keep_batchnorm_fp32=keep_batchnorm_fp32)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )
        model = DDP(
            model,
            message_size=250000000,
            gradient_predivide_factor=torch.distributed.get_world_size())

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs = 0
    model.zero_grad()
    model.train()
    train_iterator = trange(int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])
    set_seed(
        args)  # Added here for reproductibility (even between python 2 and 3)
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Train(XX Epoch) Step(X/X) (loss=X.X)",
                              disable=args.local_rank not in [-1, 0],
                              leave=True,
                              position=0)
        for step, batch in enumerate(epoch_iterator):
            batch = tuple(t.to(args.device)
                          for t in batch)  # multi-gpu does scattering it-self
            input_ids, input_mask, segment_ids, start_positions, end_positions = batch
            outputs = model(input_ids, segment_ids, input_mask,
                            start_positions, end_positions)
            loss = outputs  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel (not distributed) training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # update the learning rate after the optimizer step
                optimizer.zero_grad()
                global_step += 1
                epoch_iterator.set_description(
                    "Train(%d Epoch) Step(%d / %d) (loss=%5.5f)" %
                    (_, global_step, t_total, loss.item()))

        if args.local_rank in [-1, 0]:
            n_epochs = '{:02d}'.format(epochs)
            model_checkpoint = "korquad_{0}_{1}_{2}_{3}_{4}_{5}_{6}.bin".format(
                args.learning_rate, args.train_batch_size, n_epochs,
                int(args.num_train_epochs), args.eda_type, args.num_aug,
                args.alpha)
            logger.info(model_checkpoint)
            output_model_file = os.path.join(args.output_dir, model_checkpoint)
            if args.n_gpu > 1 or args.local_rank != -1:
                logger.info("** ** * Saving file * ** ** (module)")
                torch.save(model.module.state_dict(), output_model_file)
            else:
                logger.info("** ** * Saving file * ** **")
                torch.save(model.state_dict(), output_model_file)
        epochs += 1
    logger.info("Training End!!!")


def main():
    # os.environ['CUDA_VISIBLE_DEVICES'] = "0,1"
    batch_size = 64
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--bert_model",
        default="bert-base-uncased",
        type=str,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.",
    )
    parser.add_argument(
        "--from_pretrained",
        default="bert-base-uncased",
        type=str,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.",
    )
    parser.add_argument(
        "--output_dir",
        default="save",
        type=str,
        help=
        "The output directory where the model checkpoints will be written.",
    )
    parser.add_argument(
        "--config_file",
        default="config/bert_base_6layer_6conect.json",
        type=str,
        help="The config file which specified the model details.",
    )
    parser.add_argument(
        "--num_train_epochs",
        default=20,
        type=int,
        help="Total number of training epochs to perform.",
    )
    parser.add_argument(
        "--train_iter_multiplier",
        default=1.0,
        type=float,
        help="multiplier for the multi-task training.",
    )
    parser.add_argument(
        "--train_iter_gap",
        default=4,
        type=int,
        help=
        "forward every n iterations if the validation score has not improved over the last 3 epochs; -1 means stop",
    )
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.",
    )
    parser.add_argument("--no_cuda",
                        action="store_true",
                        help="Whether not to use CUDA when available")
    parser.add_argument(
        "--do_lower_case",
        default=True,
        type=bool,
        help=
        "Whether to lower case the input text. True for uncased models, False for cased models.",
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="local_rank for distributed training on gpus",
    )
    parser.add_argument("--seed",
                        type=int,
                        default=0,
                        help="random seed for initialization")
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help=
        "Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit float precision instead of 32-bit",
    )
    parser.add_argument(
        "--loss_scale",
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=16,
        help="Number of workers in the dataloader.",
    )
    parser.add_argument("--save_name",
                        default="",
                        type=str,
                        help="save name for training.")
    parser.add_argument(
        "--in_memory",
        default=False,
        type=bool,
        help="whether use chunck for parallel training.",
    )
    parser.add_argument("--optim",
                        default="AdamW",
                        type=str,
                        help="what to use for the optimization.")
    parser.add_argument("--tasks",
                        default="0",
                        type=str,
                        help="discourse : TASK0")
    parser.add_argument(
        "--freeze",
        default=-1,
        type=int,
        help="till which layer of textual stream of vilbert need to fixed.",
    )
    parser.add_argument(
        "--vision_scratch",
        action="store_true",
        help="whether pre-trained the image or not.",
    )
    parser.add_argument("--evaluation_interval",
                        default=1,
                        type=int,
                        help="evaluate very n epoch.")
    parser.add_argument(
        "--lr_scheduler",
        default="mannul",
        type=str,
        help="whether use learning rate scheduler.",
    )
    parser.add_argument("--baseline",
                        action="store_true",
                        help="whether use single stream baseline.")
    parser.add_argument("--resume_file",
                        default="",
                        type=str,
                        help="Resume from checkpoint")
    parser.add_argument(
        "--dynamic_attention",
        action="store_true",
        help="whether use dynamic attention.",
    )
    parser.add_argument(
        "--clean_train_sets",
        default=True,
        type=bool,
        help="whether clean train sets for multitask data.",
    )
    parser.add_argument(
        "--visual_target",
        default=0,
        type=int,
        help="which target to use for visual branch. \
        0: soft label, \
        1: regress the feature, \
        2: NCE loss.",
    )
    parser.add_argument(
        "--task_specific_tokens",
        action="store_true",
        default=False,
        help="whether to use task specific tokens for the multi-task learning.",
    )

    # todo
    args = parser.parse_args()
    with open("vilbert_tasks.yml", "r") as f:
        task_cfg = edict(yaml.safe_load(f))

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    if args.baseline:
        from pytorch_transformers.modeling_bert import BertConfig
        from vilbert.basebert import BaseBertForVLTasks
    else:
        from vilbert.vilbert import BertConfig
        from vilbert.vilbert import VILBertForVLTasks

    task_names = []
    task_lr = []
    for i, task_id in enumerate(args.tasks.split("-")):
        task = "TASK" + task_id
        name = task_cfg[task]["name"]
        task_names.append(name)
        task_lr.append(task_cfg[task]["lr"])
    base_lr = min(task_lr)
    loss_scale = {}
    for i, task_id in enumerate(args.tasks.split("-")):
        task = "TASK" + task_id
        loss_scale[task] = task_lr[i] / base_lr

    if args.save_name:
        prefix = "-" + args.save_name
    else:
        prefix = ""
    timeStamp = ("-".join("discourse") + "_" +
                 args.config_file.split("/")[1].split(".")[0] + prefix)
    savePath = os.path.join(args.output_dir, timeStamp)

    bert_weight_name = json.load(
        open("config/" + args.bert_model + "_weight_name.json", "r"))

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        torch.distributed.init_process_group(backend="nccl")

    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    default_gpu = False
    if dist.is_available() and args.local_rank != -1:
        rank = dist.get_rank()
        if rank == 0:
            default_gpu = True
    else:
        default_gpu = True

    if default_gpu:
        if not os.path.exists(savePath):
            os.makedirs(savePath)

    config = BertConfig.from_json_file(args.config_file)
    if default_gpu:
        # save all the hidden parameters.
        with open(os.path.join(savePath, "command.txt"), "w") as f:
            print(args, file=f)  # Python 3.x
            print("\n", file=f)
            print(config, file=f)

    # task_batch_size, task_num_iters, task_ids, task_datasets_train, task_datasets_val, task_dataloader_train, task_dataloader_val = LoadDatasets(
    #     args, task_cfg, args.tasks.split("-"),'train'
    # )
    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    labels = [
        "Visible", "Subjective", "Action", "Story", "Meta", "Irrelevant",
        "Other"
    ]
    train_dataset = DiscourseRelationDataset(
        labels,
        task_cfg[task]["dataroot"],
        tokenizer,
        args.bert_model,
        task_cfg[task]["max_seq_length"],
        encoding="utf-8",
        visual_target=0,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        cache=5000,
        drop_last=False,
        cuda=False,
        objective=0,
        visualization=False,
    )

    train_sampler = RandomSampler(train_dataset)

    train_loader = DataLoader(
        train_dataset,
        sampler=train_sampler,
        batch_size=batch_size,
        num_workers=0,
        pin_memory=True,
    )
    # for i in train_loader:
    #     print("hello")
    # TODO: task_ids, task_num_iters
    task_ids = ['TASK0']
    task_num_iters = [100]
    task_batch_size = task_cfg['TASK0']["batch_size"]

    print("task_batch_size")
    print(task_batch_size)
    logdir = os.path.join(savePath, "logs")
    tbLogger = utils.tbLogger(
        logdir,
        savePath,
        task_names,
        task_ids,
        task_num_iters,
        args.gradient_accumulation_steps,
    )

    if args.visual_target == 0:
        config.v_target_size = 1601
        config.visual_target = args.visual_target
    else:
        config.v_target_size = 2048
        config.visual_target = args.visual_target

    if args.task_specific_tokens:
        print("*********** config.task_specific_tokens = True ************")
        config.task_specific_tokens = True

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    task_ave_iter = {}
    task_stop_controller = {}
    # for task_id, num_iter in task_num_iters.items():
    #     task_ave_iter[task_id] = int(
    #         task_cfg[task]["num_epoch"]
    #         * num_iter
    #         * args.train_iter_multiplier
    #         / args.num_train_epochs
    #     )
    #     task_stop_controller[task_id] = utils.MultiTaskStopOnPlateau(
    #         mode="max",
    #         patience=1,
    #         continue_threshold=0.005,
    #         cooldown=1,
    #         threshold=0.001,
    #     )

    # task_ave_iter_list = sorted(task_ave_iter.values())
    # median_num_iter = task_ave_iter_list[-1]
    # num_train_optimization_steps = (
    #     median_num_iter * args.num_train_epochs // args.gradient_accumulation_steps
    # )
    # num_labels = max([dataset.num_labels for dataset in task_datasets_train.values()])

    # num_train_optimization_steps = int(
    #     train_dataset.num_dataset
    #     / task_batch_size
    #     / args.gradient_accumulation_steps
    # ) * (args.num_train_epochs - args.start_epoch)

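    # NOTE: hard-coded for this debugging run; the commented-out block above
    # shows how it would normally be derived from the dataset size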
    num_train_optimization_steps = 10
    num_labels = len(labels)
    if args.dynamic_attention:
        config.dynamic_attention = True
    if "roberta" in args.bert_model:
        config.model = "roberta"

    if args.baseline:
        model = BaseBertForVLTasks.from_pretrained(
            args.from_pretrained,
            config=config,
            num_labels=num_labels,
            default_gpu=default_gpu,
        )
    else:
        model = VILBertForVLTasks.from_pretrained(
            args.from_pretrained,
            config=config,
            num_labels=num_labels,
            default_gpu=default_gpu,
        )
    model.double()
    model = model.to(device)
    task_losses = LoadLosses(args, task_cfg, args.tasks.split("-"))

    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]

    if args.freeze != -1:
        bert_weight_name_filtered = []
        for name in bert_weight_name:
            if "embeddings" in name:
                bert_weight_name_filtered.append(name)
            elif "encoder" in name:
                layer_num = name.split(".")[2]
                if int(layer_num) <= args.freeze:
                    bert_weight_name_filtered.append(name)

        optimizer_grouped_parameters = []
        for key, value in dict(model.named_parameters()).items():
            if key[12:] in bert_weight_name_filtered:
                value.requires_grad = False

        if default_gpu:
            print("filtered weight")
            print(bert_weight_name_filtered)

    optimizer_grouped_parameters = []
    for key, value in dict(model.named_parameters()).items():
        if value.requires_grad:
            if "vil_" in key:
                lr = 1e-4
            else:
                if args.vision_scratch:
                    if key[12:] in bert_weight_name:
                        lr = base_lr
                    else:
                        lr = 1e-4
                else:
                    lr = base_lr
            if any(nd in key for nd in no_decay):
                optimizer_grouped_parameters += [{
                    "params": [value],
                    "lr": lr,
                    "weight_decay": 0.0
                }]
            if not any(nd in key for nd in no_decay):
                optimizer_grouped_parameters += [{
                    "params": [value],
                    "lr": lr,
                    "weight_decay": 0.01
                }]

    if default_gpu:
        print(len(list(model.named_parameters())),
              len(optimizer_grouped_parameters))

    if args.optim == "AdamW":
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=base_lr,
                          correct_bias=False,
                          weight_decay=1e-4)
    elif args.optim == "RAdam":
        optimizer = RAdam(optimizer_grouped_parameters,
                          lr=base_lr,
                          weight_decay=1e-4)

    # warmpu_steps = args.warmup_proportion * num_train_optimization_steps

    # if args.lr_scheduler == "warmup_linear":
    #     warmup_scheduler = WarmupLinearSchedule(
    #         optimizer, warmup_steps=warmpu_steps, t_total=num_train_optimization_steps
    #     )
    # else:
    #     warmup_scheduler = WarmupConstantSchedule(optimizer, warmup_steps=warmpu_steps)
    #
    # lr_reduce_list = np.array([5, 7])
    # if args.lr_scheduler == "automatic":
    #     lr_scheduler = ReduceLROnPlateau(
    #         optimizer, mode="max", factor=0.2, patience=1, cooldown=1, threshold=0.001
    #     )
    # elif args.lr_scheduler == "cosine":
    #     lr_scheduler = CosineAnnealingLR(
    #         optimizer, T_max=median_num_iter * args.num_train_epochs
    #     )
    # elif args.lr_scheduler == "cosine_warm":
    #     lr_scheduler = CosineAnnealingWarmRestarts(
    #         # optimizer, T_0=median_num_iter * args.num_train_epochs
    #     )
    # elif args.lr_scheduler == "mannul":
    #
    #     def lr_lambda_fun(epoch):
    #         return pow(0.2, np.sum(lr_reduce_list <= epoch))
    #
    #     lr_scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda_fun)

    startIterID = 0
    global_step = 0
    start_epoch = 0

    if args.resume_file != "" and os.path.exists(args.resume_file):
        checkpoint = torch.load(args.resume_file, map_location="cpu")
        new_dict = {}
        for attr in checkpoint["model_state_dict"]:
            if attr.startswith("module."):
                new_dict[attr.replace(
                    "module.", "", 1)] = checkpoint["model_state_dict"][attr]
            else:
                new_dict[attr] = checkpoint["model_state_dict"][attr]
        model.load_state_dict(new_dict)
        # warmup_scheduler.load_state_dict(checkpoint["warmup_scheduler_state_dict"])
        # lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        global_step = checkpoint["global_step"]
        start_epoch = int(checkpoint["epoch_id"]) + 1
        task_stop_controller = checkpoint["task_stop_controller"]
        tbLogger = checkpoint["tb_logger"]
        del checkpoint

    model.to(device)

    for state in optimizer.state.values():
        for k, v in state.items():
            if torch.is_tensor(v):
                state[k] = v.cuda()

    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )
        model = DDP(model, delay_allreduce=True)

    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    if default_gpu:
        print("***** Running training *****")
        print("  Num Iters: ", task_num_iters)
        print("  Batch size: ", batch_size)
        print("  Num steps: %d" % num_train_optimization_steps)

    task_iter_train = {name: None for name in task_ids}
    task_count = {name: 0 for name in task_ids}
    # for epochId in tqdm(range(start_epoch, args.num_train_epochs), desc="Epoch"):
    #     model.train()
    #     torch.autograd.set_detect_anomaly(True)
    #     # for step in range(median_num_iter):
    #     for step in range(1)
    #         # iterId = startIterID + step + (epochId * median_num_iter)
    #         first_task = True
    #         for task_id in task_ids:
    #             is_forward = False
    #             # if (not task_stop_controller[task_id].in_stop) or (
    #             #     iterId % args.train_iter_gap == 0
    #             # ):
    # args['start_epoch'] = 0
    # args.num_train_epochs
    criterion = nn.BCEWithLogitsLoss()
    target_path = os.path.join(task_cfg[task]["dataroot"],
                               "all_targets_json.json")
    all_targets = json.load(open(target_path, "r"))
    model = model.to(device)
    print(next(model.parameters()).is_cuda)
    for epochId in range(int(start_epoch), int(args.num_train_epochs)):
        model.train()
        is_forward = True

        if is_forward:
            # print("beforeLoop")

            # loss, score = ForwardModelsTrain(
            #     args,
            #     task_cfg,
            #     device,
            #     task_id,
            #     task_count,
            #     task_iter_train,
            #     train_dataset,
            #     model,
            #     task_losses,
            # )

            for step, batch in enumerate(train_loader):
                batch = tuple(
                    t.to(device=device, non_blocking=True)
                    if isinstance(t, torch.Tensor) else t for t in batch)
                input_ids, input_mask, segment_ids, image_feat, image_loc, image_mask, image_id = (
                    batch)
                true_targets = []
                for id in image_id:
                    true_targets.append(
                        np.fromiter(all_targets[id].values(), dtype=np.double))
                true_targets = torch.from_numpy(np.array(true_targets))
                true_targets = true_targets.to(device)
                model.double()
                model = model.to(device)
                discourse_prediction, vil_prediction, vil_prediction_gqa, vil_logit, vil_binary_prediction, vil_tri_prediction, vision_prediction, vision_logit, linguisic_prediction, linguisic_logit, _ \
                    = model(
                    True,
                    input_ids,
                    image_feat,
                    image_loc,
                    segment_ids,
                    input_mask,
                    image_mask
                )
                loss = criterion(discourse_prediction,
                                 true_targets.type(torch.double))
                loss.backward()
                optimizer.step()
                model.zero_grad()
                print("train train train done")
            #

            print("*********** ITERATION {}  ***********".format(epochId))
            print("*********** TRAIN PERFORMANCE ***********")
            print(loss)
            print(
                compute_score(discourse_prediction.to('cpu'),
                              true_targets.type(torch.float).to('cpu'), 0.5))
            print("*********** TEST PERFORMANCE ***********")
            evaluate(model, device, task_cfg, tokenizer, args, labels)
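
The loop above trains a multi-label discourse head with BCEWithLogitsLoss and scores predictions thresholded at 0.5; a minimal sketch of that objective with stand-in logits and targets (compute_score is assumed to do something along these lines):

import torch
import torch.nn as nn

logits = torch.randn(4, 7).double()             # batch of 4, 7 labels
targets = torch.randint(0, 2, (4, 7)).double()  # multi-hot targets
criterion = nn.BCEWithLogitsLoss()
loss = criterion(logits, targets)

preds = (torch.sigmoid(logits) > 0.5).double()
accuracy = (preds == targets).double().mean()
print(loss.item(), accuracy.item())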
Example #9
def train(params, args, world_rank):
    logging.info('rank %d, begin data loader init' % world_rank)
    train_data_loader = get_data_loader_distributed(params, world_rank)
    test_data_loader = get_data_loader_distributed_test(params, world_rank)
    logging.info('rank %d, data loader initialized' % world_rank)
    model = UNet.UNet(params).cuda()
    if not args.resuming:
        model.apply(model.get_weights_function(params.weight_init))

    optimizer = optimizers.FusedAdam(model.parameters(), lr=params.lr)
    #model, optimizer = amp.initialize(model, optimizer, opt_level="O1") # for automatic mixed precision
    if params.distributed:
        model = DistributedDataParallel(model)

    iters = 0
    startEpoch = 0
    checkpoint = None
    if args.resuming:
        if world_rank == 0:
            logging.info("Loading checkpoint %s" % params.checkpoint_path)
        checkpoint = torch.load(params.checkpoint_path,
                                map_location='cuda:{}'.format(args.local_rank))
        model.load_state_dict(checkpoint['model_state'])
        iters = checkpoint['iters']
        startEpoch = checkpoint['epoch'] + 1
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    if world_rank == 0:
        logging.info(model)
        logging.info("Starting Training Loop...")

    device = torch.cuda.current_device()
    for epoch in range(startEpoch, startEpoch + params.num_epochs):
        start = time.time()
        tr_time = 0.
        log_time = 0.

        for i, data in enumerate(train_data_loader, 0):
            iters += 1
            adjust_LR(optimizer, params, iters)
            inp, tar = map(lambda x: x.to(device), data)
            tr_start = time.time()
            b_size = inp.size(0)

            model.zero_grad()
            gen = model(inp)
            loss = UNet.loss_func(gen, tar, params)

            loss.backward()  # full-precision backward (cf. the AMP variant below)

            # automatic mixed precision:
            #with amp.scale_loss(loss, optimizer) as scaled_loss:
            #  scaled_loss.backward()

            optimizer.step()

            tr_end = time.time()
            tr_time += tr_end - tr_start

        # Output training stats
        if world_rank == 0:
            log_start = time.time()
            gens = []
            tars = []
            with torch.no_grad():
                for i, data in enumerate(test_data_loader, 0):
                    if i >= 50:
                        break
                    inp, tar = map(lambda x: x.to(device), data)
                    gen = model(inp)
                    gens.append(gen.detach().cpu().numpy())
                    tars.append(tar.detach().cpu().numpy())
            gens = np.concatenate(gens, axis=0)
            tars = np.concatenate(tars, axis=0)

            # Scalars
            args.tboard_writer.add_scalar('G_loss', loss.item(), iters)

            # Plots
            fig = plot_gens_tars(gens, tars)
            #fig, chi, L1score = meanL1(gens, tars)
            #args.tboard_writer.add_figure('pixhist', fig, iters, close=True)
            #args.tboard_writer.add_scalar('Metrics/chi', chi, iters)
            #args.tboard_writer.add_scalar('Metrics/rhoL1', L1score[0], iters)
            #args.tboard_writer.add_scalar('Metrics/vxL1', L1score[1], iters)
            #args.tboard_writer.add_scalar('Metrics/vyL1', L1score[2], iters)
            #args.tboard_writer.add_scalar('Metrics/vzL1', L1score[3], iters)
            #args.tboard_writer.add_scalar('Metrics/TL1', L1score[4], iters)
            #
            #fig = generate_images(inp.detach().cpu().numpy()[0], gens[-1], tars[-1])
            for figiter in range(5):
                figtag = 'test' + str(figiter)
                args.tboard_writer.add_figure(tag=figtag,
                                              figure=fig[figiter],
                                              close=True)
            #log_end = time.time()
            #log_time += log_end - log_start

            # Save checkpoint
            torch.save(
                {
                    'iters': iters,
                    'epoch': epoch,
                    'model_state': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict()
                }, params.checkpoint_path)

        end = time.time()
        if world_rank == 0:
            logging.info('Time taken for epoch {} is {} sec'.format(
                epoch + 1, end - start))
            logging.info('train step time={}, logging time={}'.format(
                tr_time, log_time))
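
The resume path above round-trips a dict checkpoint through torch.save / torch.load; a minimal sketch of the same pattern with a stand-in model and path:

import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# save: mirror the keys used by the training script
torch.save({'iters': 100, 'epoch': 3,
            'model_state': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()}, 'checkpoint.pt')

# load: restore weights and optimizer state, resume at the next epoch
ckpt = torch.load('checkpoint.pt', map_location='cpu')
model.load_state_dict(ckpt['model_state'])
optimizer.load_state_dict(ckpt['optimizer_state_dict'])
start_epoch = ckpt['epoch'] + 1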
Example #10
def main():

    parser = argparse.ArgumentParser(description='PyTorch Tacotron 2 Training')
    parser = parse_args(parser)
    args, _ = parser.parse_known_args()

    LOGGER.set_model_name("Tacotron2_PyT")
    LOGGER.set_backends([
        dllg.StdOutBackend(log_file=None,
                           logging_scope=dllg.TRAIN_ITER_SCOPE,
                           iteration_interval=1),
        dllg.JsonBackend(log_file=args.log_file if args.rank == 0 else None,
                         logging_scope=dllg.TRAIN_ITER_SCOPE,
                         iteration_interval=1)
    ])

    LOGGER.timed_block_start("run")
    LOGGER.register_metric(tags.TRAIN_ITERATION_LOSS,
                           metric_scope=dllg.TRAIN_ITER_SCOPE)
    LOGGER.register_metric("iter_time", metric_scope=dllg.TRAIN_ITER_SCOPE)
    LOGGER.register_metric("epoch_time", metric_scope=dllg.EPOCH_SCOPE)
    LOGGER.register_metric("run_time", metric_scope=dllg.RUN_SCOPE)
    LOGGER.register_metric("val_iter_loss", metric_scope=dllg.EPOCH_SCOPE)
    LOGGER.register_metric("train_epoch_items/sec",
                           metric_scope=dllg.EPOCH_SCOPE)
    LOGGER.register_metric("train_epoch_avg_items/sec",
                           metric_scope=dllg.EPOCH_SCOPE)
    LOGGER.register_metric("train_epoch_avg_loss",
                           metric_scope=dllg.EPOCH_SCOPE)

    log_hardware()

    # Restore training from checkpoint logic
    checkpoint = None
    start_epoch = 0

    model_name = args.model_name
    parser = models.parse_model_args(model_name, parser)
    args = parser.parse_args()

    log_args(args)

    torch.backends.cudnn.enabled = args.cudnn_enabled
    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    num_gpus = torch.cuda.device_count()
    print("gpus", num_gpus)
    distributed_run = num_gpus > 1
    if distributed_run:
        init_distributed(args, args.world_size, args.rank, args.group_name)

    LOGGER.log(key=tags.RUN_START)
    run_start_time = time.time()

    # Restore training from checkpoint logic
    if args.restore_from:
        print('Restoring from {} checkpoint'.format(args.restore_from))
        checkpoint = torch.load(args.restore_from, map_location='cpu')
        start_epoch = checkpoint['epoch'] + 1
        model_config = checkpoint['config']
        model = models.get_model(model_name, model_config, to_cuda=True)

        new_state_dict = {}
        for key, value in checkpoint['state_dict'].items():
            new_key = key.replace('module.', '')
            new_state_dict[new_key] = value

        model_dict = new_state_dict
        if args.warm_start:
            ignore_layers = ['embedding.weight']
            print('Warm start')

            if len(ignore_layers) > 0:
                model_dict = {
                    k: v
                    for k, v in model_dict.items() if k not in ignore_layers
                }
                dummy_dict = model.state_dict()
                dummy_dict.update(model_dict)
                model_dict = dummy_dict

        model.load_state_dict(model_dict)
    else:
        model_config = models.get_model_config(model_name, args)
        model = models.get_model(model_name, model_config, to_cuda=True)
        print("model configured")
        #model.cuda(4)
    model.cuda()
    # if not args.amp_run and distributed_run:
    #     model = DDP(model ,delay_allreduce=True)
    #
    #

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.learning_rate,
                                 weight_decay=args.weight_decay)

    # Restore training from checkpoint logic
    if checkpoint and 'optimizer_state_dict' in checkpoint and not args.warm_start:  # TODO: think about this more
        print('Restoring optimizer state')
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    if args.amp_run:
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
        print("amp initialized")

        model = DDP(model, delay_allreduce=True)
        print("ddpmodel")

    try:
        sigma = args.sigma
    except AttributeError:
        sigma = None

    print("train starting")
    criterion = loss_functions.get_loss_function(model_name, sigma)

    try:
        n_frames_per_step = args.n_frames_per_step
    except AttributeError:
        n_frames_per_step = None

    print("data loading start")
    collate_fn = data_functions.get_collate_function(model_name,
                                                     n_frames_per_step)
    trainset = data_functions.get_data_loader(model_name, args.training_files,
                                              args)

    train_sampler = DistributedSampler(trainset) if distributed_run else None
    print("train loader started")

    train_loader = DataLoader(trainset,
                              num_workers=1,
                              shuffle=False,
                              sampler=train_sampler,
                              batch_size=args.batch_size,
                              pin_memory=False,
                              drop_last=True,
                              collate_fn=collate_fn)

    valset = data_functions.get_data_loader(model_name, args.validation_files,
                                            args)

    batch_to_gpu = data_functions.get_batch_to_gpu(model_name)

    iteration = 0
    model.train()

    LOGGER.log(key=tags.TRAIN_LOOP)

    # Restore training from checkpoint logic
    if start_epoch >= args.epochs:
        print('Checkpoint epoch {} >= total epochs {}'.format(
            start_epoch, args.epochs))
    else:
        for epoch in range(start_epoch, args.epochs):
            LOGGER.epoch_start()
            epoch_start_time = time.time()
            LOGGER.log(key=tags.TRAIN_EPOCH_START, value=epoch)

            # used to calculate avg items/sec over epoch
            reduced_num_items_epoch = 0

            # used to calculate avg loss over epoch
            train_epoch_avg_loss = 0.0
            train_epoch_avg_items_per_sec = 0.0
            num_iters = 0

            # if overflow at the last iteration then do not save checkpoint
            overflow = False

            for i, batch in enumerate(train_loader):

                print("Batch: {}/{} epoch {}".format(i, len(train_loader),
                                                     epoch))
                LOGGER.iteration_start()
                iter_start_time = time.time()
                LOGGER.log(key=tags.TRAIN_ITER_START, value=i)

                start = time.perf_counter()
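                # anneal the learning rate based on the current epoch and the
                # configured anneal steps/factor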
                adjust_learning_rate(epoch, optimizer, args.learning_rate,
                                     args.anneal_steps, args.anneal_factor)

                model.zero_grad()

                x, y, num_items = batch_to_gpu(batch)

                y_pred = model(x)

                loss = criterion(y_pred, y)

                if distributed_run:
                    reduced_loss = reduce_tensor(loss.data,
                                                 args.world_size).item()
                    reduced_num_items = reduce_tensor(num_items.data, 1).item()
                else:
                    reduced_loss = loss.item()
                    reduced_num_items = num_items.item()
                if np.isnan(reduced_loss):
                    raise Exception("loss is NaN")

                LOGGER.log(key=tags.TRAIN_ITERATION_LOSS, value=reduced_loss)

                train_epoch_avg_loss += reduced_loss
                num_iters += 1

                # accumulate number of items processed in this epoch
                reduced_num_items_epoch += reduced_num_items

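                # backward pass: under AMP the loss is scaled to avoid fp16
                # gradient underflow and clipping is applied to the unscaled
                # master params; otherwise clip the model gradients directly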
                if args.amp_run:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.grad_clip_thresh)
                else:
                    loss.backward()
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(), args.grad_clip_thresh)

                optimizer.step()

                iteration += 1

                LOGGER.log(key=tags.TRAIN_ITER_STOP, value=i)

                iter_stop_time = time.time()
                iter_time = iter_stop_time - iter_start_time
                items_per_sec = reduced_num_items / iter_time
                train_epoch_avg_items_per_sec += items_per_sec

                LOGGER.log(key="train_iter_items/sec", value=items_per_sec)
                LOGGER.log(key="iter_time", value=iter_time)
                LOGGER.iteration_stop()

            LOGGER.log(key=tags.TRAIN_EPOCH_STOP, value=epoch)
            epoch_stop_time = time.time()
            epoch_time = epoch_stop_time - epoch_start_time

            LOGGER.log(key="train_epoch_items/sec",
                       value=(reduced_num_items_epoch / epoch_time))
            LOGGER.log(key="train_epoch_avg_items/sec",
                       value=(train_epoch_avg_items_per_sec /
                              num_iters if num_iters > 0 else 0.0))
            LOGGER.log(key="train_epoch_avg_loss",
                       value=(train_epoch_avg_loss /
                              num_iters if num_iters > 0 else 0.0))
            LOGGER.log(key="epoch_time", value=epoch_time)

            LOGGER.log(key=tags.EVAL_START, value=epoch)

            validate(model, criterion, valset, iteration, args.batch_size,
                     args.world_size, collate_fn, distributed_run, args.rank,
                     batch_to_gpu)

            LOGGER.log(key=tags.EVAL_STOP, value=epoch)

            if (epoch % args.epochs_per_checkpoint == 0) and args.rank == 0:
                checkpoint_path = os.path.join(
                    args.output_directory,
                    "checkpoint_{}_{}".format(model_name, epoch))
                save_checkpoint(model, epoch, model_config, optimizer,
                                checkpoint_path)
                save_sample(
                    model_name, model, args.waveglow_checkpoint,
                    args.tacotron2_checkpoint, args.phrase_path,
                    os.path.join(
                        args.output_directory,
                        "sample_{}_{}.wav".format(model_name, iteration)),
                    args.sampling_rate)

            LOGGER.epoch_stop()

    run_stop_time = time.time()
    run_time = run_stop_time - run_start_time
    LOGGER.log(key="run_time", value=run_time)
    LOGGER.log(key=tags.RUN_FINAL)

    print("training time", run_stop_time - run_start_time)

    LOGGER.timed_block_stop("run")

    if args.rank == 0:
        LOGGER.finish()
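
The loop above averages the loss and item counts across workers with a reduce_tensor helper defined elsewhere in this script. A minimal sketch of the usual implementation (assuming torch.distributed has already been initialized) looks like this:

import torch
import torch.distributed as dist

def reduce_tensor(tensor, num_ranks):
    # sum the tensor across all workers, then divide by the rank count;
    # calling it with num_ranks=1 (as done for num_items above) keeps the
    # global sum instead of the average
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    return rt / num_ranks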
Example #11
def train_model(model, train_iter, valid_iter, args):

    optimizer = Adam(model.parameters(), lr=args.learning_rate)
    # amp.initialize requires an opt_level; "O1" (mixed precision) is the usual choice
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
    # model = DistributedDataParallel(model, device_ids=[args.local_rank])
    model = DistributedDataParallel(model)
    """
    optimizer = Adam([
    {'params': model.bert.parameters(), 'lr': 1e-5},
    {'params': model.classifier.parameters(), 'lr': 3e-4}])
    """
    # Alternative LR schedulers, kept for reference:
    # scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=1, verbose=1, min_lr=0.0001)
    # scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=len(train_iter) // gradient_accumulation_steps * num_epochs)
    scheduler = get_constant_schedule_with_warmup(optimizer,
                                                  num_warmup_steps=args.warmup_steps)
    total_steps = 0
    best_f1 = 0
    stop_count = 0
    start_time = time.time()
    for epoch in range(args.num_epochs):
        model.train()
        epoch_loss = 0

        for step, batch in enumerate(train_iter):
            inputs = {
                "input_ids": batch[0].cuda(non_blocking=True),
                "token_type_ids": batch[1].cuda(non_blocking=True),
                "attention_mask": batch[2].cuda(non_blocking=True)
            }
            total_steps += 1
            labels = batch[3].cuda(non_blocking=True)
            logits = model(**inputs)
            # compute the loss
            loss = F.cross_entropy(logits, labels)
            # wrap the loss so AMP can scale it before backward
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            # accumulate the unscaled loss (scaled_loss is multiplied by the loss scale)
            epoch_loss += loss.item()
            # gradient accumulation: update the weights every N steps
            if total_steps % args.gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()
                model.zero_grad()
            if total_steps % args.eval_steps == 0:
                end_time = timeSince(start_time)
                print("epoch: {}, eval_steps: {}, time: {}".format(epoch + 1, total_steps, end_time))
                p,r,f1 = evaluate_valid(model, valid_iter)

                print("valid_p:{:.4f},valid_r:{:.4f},valid_f1:{:.4f}".format(p, r, f1))

                if f1 > best_f1:
                    best_f1 = f1
                    # save only the weights; torch.save(model, ...) would save
                    # the whole model object instead
                    torch.save(model.state_dict(), args.save_path)
                    # release cached GPU memory
                    torch.cuda.empty_cache()
                # evaluate_valid presumably switches the model to eval mode;
                # switch back before resuming training
                model.train()
        # print the average epoch loss
        print('Epoch {} - Loss {:.4f}'.format(epoch + 1, epoch_loss / len(train_iter)))
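
evaluate_valid and timeSince are helper functions defined elsewhere in this project. A plausible minimal sketch of the evaluation routine (hypothetical implementation, assuming a classification head and sklearn for the metrics) could look like:

import torch
from sklearn.metrics import precision_recall_fscore_support

def evaluate_valid(model, valid_iter):
    # run the model over the validation set and compute macro P/R/F1
    model.eval()
    preds, golds = [], []
    with torch.no_grad():
        for batch in valid_iter:
            inputs = {
                "input_ids": batch[0].cuda(non_blocking=True),
                "token_type_ids": batch[1].cuda(non_blocking=True),
                "attention_mask": batch[2].cuda(non_blocking=True),
            }
            logits = model(**inputs)
            preds.extend(logits.argmax(dim=-1).cpu().tolist())
            golds.extend(batch[3].tolist())
    p, r, f1, _ = precision_recall_fscore_support(golds, preds, average='macro')
    return p, r, f1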
Example #12
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument("--vocab_file",
                        default='bert-base-uncased-vocab.txt',
                        type=str,
                        required=True)
    parser.add_argument("--model_file",
                        default='bert-base-uncased.tar.gz',
                        type=str,
                        required=True)
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model checkpoints and predictions will be written."
    )
    parser.add_argument(
        "--predict_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the predictions will be written.")

    # Other parameters
    parser.add_argument("--train_file",
                        default=None,
                        type=str,
                        help="SQuAD json for training. E.g., train-v1.1.json")
    parser.add_argument(
        "--predict_file",
        default=None,
        type=str,
        help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json"
    )
    parser.add_argument("--test_file", default=None, type=str)
    parser.add_argument(
        "--max_seq_length",
        default=384,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. Sequences "
        "longer than this will be truncated, and sequences shorter than this will be padded."
    )
    parser.add_argument(
        "--doc_stride",
        default=128,
        type=int,
        help=
        "When splitting up a long document into chunks, how much stride to take between chunks."
    )
    parser.add_argument(
        "--max_query_length",
        default=64,
        type=int,
        help=
        "The maximum number of tokens for the question. Questions longer than this will "
        "be truncated to this length.")
    parser.add_argument("--do_train",
                        default=False,
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_predict",
                        default=False,
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--predict_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for predictions.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=2.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10% "
        "of training.")
    parser.add_argument(
        "--n_best_size",
        default=20,
        type=int,
        help=
        "The total number of n-best predictions to generate in the nbest_predictions.json "
        "output file.")
    parser.add_argument(
        "--max_answer_length",
        default=30,
        type=int,
        help=
        "The maximum length of an answer that can be generated. This is needed because the start "
        "and end predictions are not conditioned on one another.")
    parser.add_argument(
        "--verbose_logging",
        default=False,
        action='store_true',
        help=
        "If true, all of the warnings related to data processing will be printed. "
        "A number of warnings are expected for a normal SQuAD evaluation.")
    parser.add_argument("--no_cuda",
                        default=False,
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--view_id',
                        type=int,
                        default=1,
                        help="view id of multi-view co-training(two-view)")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        "--do_lower_case",
        default=True,
        action='store_true',
        help=
        "Whether to lower case the input text. True for uncased models, False for cased models."
    )
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument(
        '--fp16',
        default=False,
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--fp16_opt_level',
        type=str,
        default='O1',
        help=
        "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
    )
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--save_all', default=False, action='store_true')
    parser.add_argument('--max_grad_norm', default=1.0, type=float)
    parser.add_argument('--weight_decay', default=0.0, type=float)
    parser.add_argument('--adam_epsilon', default=1e-8, type=float)
    parser.add_argument('--patience', default=5, type=int)

    # Base setting
    parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--max_ctx', type=int, default=2)
    parser.add_argument('--task_name', type=str, default='race')
    parser.add_argument('--bert_name', type=str, default='pool-race')
    parser.add_argument('--reader_name', type=str, default='race')
    parser.add_argument('--per_eval_step', type=int, default=10000000)
    # model parameters
    parser.add_argument('--evidence_lambda', type=float, default=0.8)
    # Parameters for running labeling model
    parser.add_argument('--do_label', default=False, action='store_true')
    parser.add_argument('--sentence_id_file', nargs='*')
    parser.add_argument('--weight_threshold', type=float, default=0.0)
    parser.add_argument('--only_correct', default=False, action='store_true')
    parser.add_argument('--label_threshold', type=float, default=0.0)
    parser.add_argument('--multi_evidence', default=False, action='store_true')
    parser.add_argument('--metric', default='accuracy', type=str)
    parser.add_argument('--num_evidence', default=1, type=int)
    parser.add_argument('--power_length', default=1., type=float)
    parser.add_argument('--num_choices', default=4, type=int)
    parser.add_argument('--split_type', default=0, type=int)

    args = parser.parse_args()

    logger = setting_logger(args.output_dir)
    logger.info('================== Program start. ========================')

    model_params = prepare_model_params(args)
    read_params = prepare_read_params(args)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend, which takes care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_predict and not args.do_label:
        raise ValueError(
            "At least one of `do_train` or `do_predict` or `do_label` must be True."
        )

    if args.do_train:
        if not args.train_file:
            raise ValueError(
                "If `do_train` is True, then `train_file` must be specified.")
    if args.do_predict:
        if not args.predict_file:
            raise ValueError(
                "If `do_predict` is True, then `predict_file` must be specified."
            )

    if args.do_train:
        if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
            raise ValueError(
                "Output directory ({}) already exists and is not empty.".format(
                    args.output_dir))
        os.makedirs(args.output_dir, exist_ok=True)
        os.makedirs(os.path.join(args.output_dir, "best_model"), exist_ok=True)
        os.makedirs(os.path.join(args.output_dir, "best_loss_model"),
                    exist_ok=True)

    if args.do_predict or args.do_label:
        os.makedirs(args.predict_dir, exist_ok=True)

    # tokenizer = BertTokenizer.from_pretrained(args.vocab_file)
    tokenizer = get_tokenizer(args.bert_model).from_pretrained(args.vocab_file)

    data_reader = initialize_reader(args.reader_name)

    num_train_steps = None
    if args.do_train or args.do_label:
        train_examples = data_reader.read(input_file=args.train_file,
                                          **read_params)

        cached_train_features_file = args.train_file + '_{0}_{1}_{2}_{3}_{4}_{5}'.format(
            args.bert_model, str(args.max_seq_length), str(args.doc_stride),
            str(args.max_query_length), str(args.max_ctx), str(args.task_name))

        try:
            with open(cached_train_features_file, "rb") as reader:
                train_features = pickle.load(reader)
        except FileNotFoundError:
            train_features = data_reader.convert_examples_to_features(
                examples=train_examples,
                tokenizer=tokenizer,
                max_seq_length=args.max_seq_length)
            if args.local_rank == -1 or torch.distributed.get_rank() == 0:
                logger.info("  Saving train features into cached file %s",
                            cached_train_features_file)
                with open(cached_train_features_file, "wb") as writer:
                    pickle.dump(train_features, writer)

        num_train_steps = int(
            len(train_features) / args.train_batch_size /
            args.gradient_accumulation_steps * args.num_train_epochs)

    # Prepare model
    if args.pretrain is not None:
        logger.info('Load pretrained model from {}'.format(args.pretrain))
        model_state_dict = torch.load(args.pretrain, map_location='cuda:0')
        model = initialize_model(args.bert_name,
                                 args.model_file,
                                 state_dict=model_state_dict,
                                 **model_params)
    else:
        model = initialize_model(args.bert_name, args.model_file,
                                 **model_params)

    # if args.fp16:
    #     model.half()
    model.to(device)

    t_total = num_train_steps if num_train_steps is not None else -1
    if args.local_rank != -1:
        t_total = t_total // torch.distributed.get_world_size()

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
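    # by convention, biases and LayerNorm weights are excluded from weight decay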
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer,
                                     warmup_steps=int(args.warmup_proportion *
                                                      t_total),
                                     t_total=t_total)

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare data
    eval_examples = data_reader.read(input_file=args.predict_file,
                                     **read_params)
    eval_features = data_reader.convert_examples_to_features(
        examples=eval_examples,
        tokenizer=tokenizer,
        max_seq_length=args.max_seq_length)

    eval_tensors = data_reader.data_to_tensors(eval_features)
    eval_data = TensorDataset(*eval_tensors)
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data,
                                 sampler=eval_sampler,
                                 batch_size=args.predict_batch_size)

    if args.do_train:

        if args.do_label:
            logger.info('Training in state-wise mode.')
            sentence_label_file = args.sentence_id_file
            if sentence_label_file is not None:
                for file in sentence_label_file:
                    train_features = data_reader.generate_features_sentence_ids(
                        train_features, file)
            else:
                logger.info('No sentence id supervision found.')
        else:
            logger.info('Training in the traditional way.')

        logger.info("***** Running training *****")
        logger.info("  Num orig examples = %d", len(train_examples))
        logger.info("  Num split examples = %d", len(train_features))
        logger.info("  Num train total optimization steps = %d", t_total)
        logger.info("  Batch size = %d", args.train_batch_size)
        train_loss = AverageMeter()
        best_acc = 0.0
        best_loss = 1000000
        summary_writer = SummaryWriter(log_dir=args.output_dir)
        global_step = 0
        eval_loss = AverageMeter()
        eval_accuracy = CategoricalAccuracy()
        eval_epoch = 0
        last_update = 0

        train_tensors = data_reader.data_to_tensors(train_features)
        train_data = TensorDataset(*train_tensors)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        for epoch in range(int(args.num_train_epochs)):
            logger.info(f'Running at Epoch {epoch}')
            # Train
            for step, batch in enumerate(
                    tqdm(train_dataloader,
                         desc="Iteration",
                         dynamic_ncols=True)):
                model.train()
                if n_gpu == 1:
                    batch = batch_to_device(
                        batch, device)  # multi-gpu does scattering itself
                inputs = data_reader.generate_inputs(
                    batch, train_features, model_state=ModelState.Train)
                model_output = model(**inputs)
                loss = model_output['loss']
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                if (step + 1) % args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    scheduler.step()
                    model.zero_grad()
                    global_step += 1

                    lr_this_step = scheduler.get_lr()[0]
                    summary_writer.add_scalar('lr', lr_this_step, global_step)

                    batch_size = inputs["labels"].size(0)
                    train_loss.update(loss.item() * batch_size, batch_size)
                    summary_writer.add_scalar('train_loss', train_loss.avg,
                                              global_step)

                    if global_step % args.per_eval_step == 0:
                        # Evaluation
                        model.eval()
                        logger.info("Start evaluating")
                        for _, eval_batch in enumerate(
                                tqdm(eval_dataloader,
                                     desc="Evaluating",
                                     dynamic_ncols=True)):
                            if n_gpu == 1:
                                eval_batch = batch_to_device(
                                    eval_batch, device
                                )  # multi-gpu does scattering itself
                            inputs = data_reader.generate_inputs(
                                eval_batch,
                                eval_features,
                                model_state=ModelState.Evaluate)
                            batch_size = inputs["labels"].size(0)
                            with torch.no_grad():
                                output_dict = model(**inputs)
                                loss, choice_logits = output_dict[
                                    'loss'], output_dict['choice_logits']
                                eval_loss.update(loss.item() * batch_size,
                                                 batch_size)
                                eval_accuracy(choice_logits, inputs["labels"])

                        eval_epoch_loss = eval_loss.avg
                        summary_writer.add_scalar('eval_loss', eval_epoch_loss,
                                                  global_step)
                        eval_loss.reset()
                        current_acc = eval_accuracy.get_metric(reset=True)
                        summary_writer.add_scalar('eval_acc', current_acc,
                                                  global_step)
                        torch.cuda.empty_cache()

                        if args.save_all:
                            model_to_save = model.module if hasattr(
                                model, 'module'
                            ) else model  # only save the model itself
                            output_model_file = os.path.join(
                                args.output_dir,
                                f"checkpoint-{global_step}.bin")
                            model_to_save.save_pretrained(output_model_file)
                            # torch.save(model_to_save.state_dict(), output_model_file)

                        if current_acc > best_acc:
                            best_acc = current_acc
                            model_to_save = model.module if hasattr(
                                model, 'module'
                            ) else model  # only save the model itself
                            output_model_file = os.path.join(
                                args.output_dir, "best_model")
                            model_to_save.save_pretrained(output_model_file)
                            # torch.save(model_to_save.state_dict(), output_model_file)
                            last_update = global_step // args.per_eval_step
                        if eval_epoch_loss < best_loss:
                            best_loss = eval_epoch_loss
                            model_to_save = model.module if hasattr(
                                model, 'module'
                            ) else model  # only save the model itself
                            output_model_file = os.path.join(
                                args.output_dir, "best_loss_model")
                            model_to_save.save_pretrained(output_model_file)
                            # torch.save(model_to_save.state_dict(), output_model_file)

                        logger.info(
                            'Global Step: %d, Accuracy: %.4f (Best Accuracy: %.4f)'
                            % (global_step, current_acc, best_acc))
                        eval_epoch += 1

                        if global_step // args.per_eval_step - last_update >= args.patience:
                            logger.info(
                                f"Training reach patience: {args.patience}, training stopped."
                            )
                            break

            if global_step // args.per_eval_step - last_update >= args.patience:
                break

            logger.info(
                f'Epoch {epoch}: Accuracy: {best_acc}, Train Loss: {train_loss.avg}'
            )
        summary_writer.close()

    for output_model_name in ["best_model", "best_loss_model"]:
        # Loading trained model
        output_model_file = os.path.join(args.output_dir, output_model_name)
        # model_state_dict = torch.load(output_model_file, map_location='cuda:0')
        # model = initialize_model(args.bert_name, args.model_file, state_dict=model_state_dict, **model_params)
        model = initialize_model(args.bert_name, output_model_file,
                                 **model_params)
        model.to(device)

        # Write Yes/No predictions
        if args.do_predict and (args.local_rank == -1
                                or torch.distributed.get_rank() == 0):

            test_examples = data_reader.read(args.test_file)
            test_features = data_reader.convert_examples_to_features(
                test_examples, tokenizer, args.max_seq_length)

            test_tensors = data_reader.data_to_tensors(test_features)
            test_data = TensorDataset(*test_tensors)
            test_sampler = SequentialSampler(test_data)
            test_dataloader = DataLoader(test_data,
                                         sampler=test_sampler,
                                         batch_size=args.predict_batch_size)

            logger.info("***** Running predictions *****")
            logger.info("  Num orig examples = %d", len(test_examples))
            logger.info("  Num split examples = %d", len(test_features))
            logger.info("  Batch size = %d", args.predict_batch_size)

            model.eval()
            all_results = []
            test_acc = CategoricalAccuracy()
            logger.info("Start predicting yes/no on Dev set.")
            for batch in tqdm(test_dataloader, desc="Testing"):
                if n_gpu == 1:
                    batch = batch_to_device(
                        batch, device)  # multi-gpu does scattering itself
                inputs = data_reader.generate_inputs(
                    batch, test_features, model_state=ModelState.Evaluate)
                with torch.no_grad():
                    batch_choice_logits = model(**inputs)['choice_logits']
                    test_acc(batch_choice_logits, inputs['labels'])
                example_indices = batch[-1]
                for i, example_index in enumerate(example_indices):
                    choice_logits = batch_choice_logits[i].detach().cpu(
                    ).tolist()

                    test_feature = test_features[example_index.item()]
                    unique_id = int(test_feature.unique_id)

                    all_results.append(
                        RawResultChoice(unique_id=unique_id,
                                        choice_logits=choice_logits))

            if "loss" in output_model_name:
                logger.info(
                    'Predicting question choice on test set using model with lowest loss on validation set.'
                )
                output_prediction_file = os.path.join(args.predict_dir,
                                                      'loss_predictions.json')
            else:
                logger.info(
                    'Predicting question choice on test set using model with best accuracy on validation set.'
                )
                output_prediction_file = os.path.join(args.predict_dir,
                                                      'predictions.json')
            data_reader.write_predictions(test_examples, test_features,
                                          all_results, output_prediction_file)
            logger.info(
                f"Accuracy on Test set: {test_acc.get_metric(reset=True)}")

    # Loading trained model.
    if args.metric == 'accuracy':
        logger.info("Load model with best accuracy on validation set.")
        output_model_file = os.path.join(args.output_dir, "best_model")
    elif args.metric == 'loss':
        logger.info("Load model with lowest loss on validation set.")
        output_model_file = os.path.join(args.output_dir, "best_loss_model")
    else:
        raise RuntimeError(
            f"Wrong metric type for {args.metric}, which must be in ['accuracy', 'loss']."
        )
    # model_state_dict = torch.load(output_model_file, map_location='cuda:0')
    # model = initialize_model(args.bert_name, args.model_file, state_dict=model_state_dict, **model_params)
    model = initialize_model(args.bert_name, output_model_file, **model_params)
    model.to(device)

    # Labeling sentence id.
    if args.do_label and (args.local_rank == -1
                          or torch.distributed.get_rank() == 0):

        f = open('debug_log.txt', 'w')

        def softmax(x):
            """Compute softmax values for each sets of scores in x."""
            e_x = np.exp(x - np.max(x))
            return e_x / e_x.sum()

        def topk(sentence_sim):
            """
            Select the evidence sentences whose averaged tail log-softmax
            score is maximal.

            :param sentence_sim: 1-D numpy array of per-sentence similarity scores
            :return: dict with the selected sentence indices ('sentences') and
                a confidence value ('value')
            """
            max_length = min(args.num_evidence, len(sentence_sim))
            sorted_scores = np.array(sorted(sentence_sim, reverse=True))
            scores = []
            for idx in range(max_length):
                scores.append(np.log(softmax(sorted_scores[idx:])[0]))
            scores = [np.mean(scores[:(j + 1)]) for j in range(max_length)]
            top_k = int(np.argmax(scores) + 1)
            sorted_scores = sorted(enumerate(sentence_sim),
                                   key=lambda x: x[1],
                                   reverse=True)
            evidence_ids = [x[0] for x in sorted_scores[:top_k]]
            sentence = {
                'sentences': evidence_ids,
                'value': float(np.exp(scores[top_k - 1]))
            }
            return sentence

        def batch_topk(sentence_sim, sentence_mask):
            batch_size = sentence_sim.size(0)
            num_choices = sentence_sim.size(1)
            sentence_sim = sentence_sim.numpy() + 1e-15
            sentence_mask = sentence_mask.numpy()
            sentence_ids = []
            for b in range(batch_size):
                choice_sentence_ids = [
                    topk(_sim[:int(sum(_mask))])
                    for _sim, _mask in zip(sentence_sim[b], sentence_mask[b])
                ]
                assert len(choice_sentence_ids) == num_choices
                sentence_ids.append(choice_sentence_ids)
            return sentence_ids

        test_examples = train_examples
        test_features = train_features

        test_tensors = data_reader.data_to_tensors(test_features)
        test_data = TensorDataset(*test_tensors)
        test_sampler = SequentialSampler(test_data)
        test_dataloader = DataLoader(test_data,
                                     sampler=test_sampler,
                                     batch_size=args.predict_batch_size)

        logger.info("***** Running labeling *****")
        logger.info("  Num orig examples = %d", len(test_examples))
        logger.info("  Num split examples = %d", len(test_features))
        logger.info("  Batch size = %d", args.predict_batch_size)

        model.eval()
        all_results = []
        logger.info("Start labeling.")
        for batch in tqdm(test_dataloader, desc="Testing"):
            if n_gpu == 1:
                batch = batch_to_device(batch, device)
            inputs = data_reader.generate_inputs(batch,
                                                 test_features,
                                                 model_state=ModelState.Test)
            with torch.no_grad():
                output_dict = model(**inputs)
                batch_choice_logits, batch_sentence_logits = output_dict[
                    "choice_logits"], output_dict["sentence_logits"]
                batch_sentence_mask = output_dict["sentence_mask"]
            example_indices = batch[-1]
            # batch_beam_results = batch_choice_beam_search(batch_sentence_logits, batch_sentence_mask)
            batch_topk_results = batch_topk(batch_sentence_logits,
                                            batch_sentence_mask)
            for i, example_index in enumerate(example_indices):
                choice_logits = batch_choice_logits[i].detach().cpu()
                evidence_list = batch_topk_results[i]

                test_feature = test_features[example_index.item()]
                unique_id = int(test_feature.unique_id)

                all_results.append(
                    RawOutput(unique_id=unique_id,
                              model_output={
                                  "choice_logits": choice_logits,
                                  "evidence_list": evidence_list
                              }))

        output_prediction_file = os.path.join(args.predict_dir,
                                              'sentence_id_file.json')
        data_reader.predict_sentence_ids(
            test_examples,
            test_features,
            all_results,
            output_prediction_file,
            weight_threshold=args.weight_threshold,
            only_correct=args.only_correct,
            label_threshold=args.label_threshold)
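
Both training scripts above track running means with an AverageMeter helper imported elsewhere. The classic implementation from the PyTorch ImageNet example (the variant used here may weight update() slightly differently) is:

class AverageMeter(object):
    """Computes and stores the current value and the running average."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        # val is the value for the latest batch; n weights it (e.g. batch size)
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count if self.count else 0.0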
Example #13
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument("--vocab_file",
                        default='bert-base-uncased-vocab.txt',
                        type=str,
                        required=True)
    parser.add_argument("--model_file",
                        default='bert-base-uncased.tar.gz',
                        type=str,
                        required=True)
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model checkpoints and predictions will be written."
    )
    parser.add_argument(
        "--predict_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the predictions will be written.")

    # Other parameters
    parser.add_argument("--train_file",
                        default=None,
                        type=str,
                        help="SQuAD json for training. E.g., train-v1.1.json")
    parser.add_argument(
        "--predict_file",
        default=None,
        type=str,
        help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json"
    )
    parser.add_argument(
        "--max_seq_length",
        default=384,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. Sequences "
        "longer than this will be truncated, and sequences shorter than this will be padded."
    )
    parser.add_argument(
        "--doc_stride",
        default=128,
        type=int,
        help=
        "When splitting up a long document into chunks, how much stride to take between chunks."
    )
    parser.add_argument(
        "--max_query_length",
        default=64,
        type=int,
        help=
        "The maximum number of tokens for the question. Questions longer than this will "
        "be truncated to this length.")
    parser.add_argument("--do_train",
                        default=False,
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_predict",
                        default=False,
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--predict_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for predictions.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=2.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10% "
        "of training.")
    parser.add_argument(
        "--n_best_size",
        default=20,
        type=int,
        help=
        "The total number of n-best predictions to generate in the nbest_predictions.json "
        "output file.")
    parser.add_argument(
        "--max_answer_length",
        default=30,
        type=int,
        help=
        "The maximum length of an answer that can be generated. This is needed because the start "
        "and end predictions are not conditioned on one another.")
    parser.add_argument(
        "--verbose_logging",
        default=False,
        action='store_true',
        help=
        "If true, all of the warnings related to data processing will be printed. "
        "A number of warnings are expected for a normal SQuAD evaluation.")
    parser.add_argument("--no_cuda",
                        default=False,
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--view_id',
                        type=int,
                        default=1,
                        help="view id of multi-view co-training(two-view)")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        "--do_lower_case",
        default=True,
        action='store_true',
        help=
        "Whether to lower case the input text. True for uncased models, False for cased models."
    )
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument(
        '--fp16',
        default=False,
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--fp16_opt_level',
        type=str,
        default='O1',
        help=
        "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
    )
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--adam_epsilon', type=float, default=1e-8)
    parser.add_argument('--max_grad_norm', type=float, default=1.0)

    # Base setting
    parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--max_ctx', type=int, default=2)
    parser.add_argument('--task_name', type=str, default='coqa_yesno')
    parser.add_argument('--bert_name', type=str, default='baseline')
    parser.add_argument('--reader_name', type=str, default='coqa')
    # model parameters
    parser.add_argument('--evidence_lambda', type=float, default=0.8)
    parser.add_argument('--tf_layers', type=int, default=1)
    parser.add_argument('--tf_inter_size', type=int, default=3072)
    # Parameters for running labeling model
    parser.add_argument('--do_label', default=False, action='store_true')
    parser.add_argument('--sentence_id_files', nargs='*')
    parser.add_argument('--weight_threshold', type=float, default=0.0)
    parser.add_argument('--only_correct', default=False, action='store_true')
    parser.add_argument('--label_threshold', type=float, default=0.0)

    args = parser.parse_args()

    logger = setting_logger(args.output_dir)
    logger.info('================== Program start. ========================')

    model_params = prepare_model_params(args)
    read_params = prepare_read_params(args)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend, which takes care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_predict:
        raise ValueError(
            "At least one of `do_train` or `do_predict` must be True.")

    if args.do_train:
        if not args.train_file:
            raise ValueError(
                "If `do_train` is True, then `train_file` must be specified.")
    if args.do_predict:
        if not args.predict_file:
            raise ValueError(
                "If `do_predict` is True, then `predict_file` must be specified."
            )

    if args.do_train:
        if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
            raise ValueError(
                "Output directory ({}) already exists and is not empty.".format(
                    args.output_dir))
        os.makedirs(args.output_dir, exist_ok=True)

    if args.do_predict:
        os.makedirs(args.predict_dir, exist_ok=True)

    tokenizer = BertTokenizer.from_pretrained(args.vocab_file)

    data_reader = initialize_reader(args.reader_name)

    num_train_steps = None
    if args.do_train or args.do_label:
        train_examples = data_reader.read(input_file=args.train_file,
                                          **read_params)

        cached_train_features_file = args.train_file + '_{0}_{1}_{2}_{3}_{4}_{5}'.format(
            args.bert_model, str(args.max_seq_length), str(args.doc_stride),
            str(args.max_query_length), str(args.max_ctx), str(args.task_name))

        try:
            with open(cached_train_features_file, "rb") as reader:
                train_features = pickle.load(reader)
        except FileNotFoundError:
            train_features = data_reader.convert_examples_to_features(
                examples=train_examples,
                tokenizer=tokenizer,
                max_seq_length=args.max_seq_length,
                doc_stride=args.doc_stride,
                max_query_length=args.max_query_length,
                is_training=True)
            if args.local_rank == -1 or torch.distributed.get_rank() == 0:
                logger.info("  Saving train features into cached file %s",
                            cached_train_features_file)
                with open(cached_train_features_file, "wb") as writer:
                    pickle.dump(train_features, writer)
        print(train_features[-1].unique_id)
        num_train_steps = int(
            len(train_features) / args.train_batch_size /
            args.gradient_accumulation_steps * args.num_train_epochs)

    # Prepare model
    if args.pretrain is not None:
        logger.info('Load pretrained model from {}'.format(args.pretrain))
        model_state_dict = torch.load(args.pretrain, map_location='cuda:0')
        model = initialize_model(args.bert_name,
                                 args.model_file,
                                 state_dict=model_state_dict,
                                 **model_params)
    else:
        model = initialize_model(args.bert_name, args.model_file,
                                 **model_params)

    # apex amp performs the fp16 cast in amp.initialize below, so the model
    # must not be cast with model.half() here.
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())

    # hack to remove the pooler, which is not used
    # and thus produces None grads that break apex
    param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    t_total = num_train_steps
    # if args.local_rank != -1:
    #     t_total = t_total // torch.distributed.get_world_size()
    # if args.fp16:
    #     try:
    #         from apex.optimizers import FP16_Optimizer
    #         from apex.optimizers import FusedAdam
    #     except ImportError:
    #         raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
    #
    #     optimizer = FusedAdam(optimizer_grouped_parameters,
    #                           lr=args.learning_rate,
    #                           bias_correction=False,
    #                           max_grad_norm=1.0)
    #     if args.loss_scale == 0:
    #         optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
    #     else:
    #         optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
    # else:
    #     optimizer = BertAdam(optimizer_grouped_parameters,
    #                          lr=args.learning_rate,
    #                          warmup=args.warmup_proportion,
    #                          t_total=t_total)
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer,
                                     warmup_steps=args.warmup_proportion *
                                     t_total,
                                     t_total=t_total)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # Prepare data
    eval_examples = data_reader.read(input_file=args.predict_file,
                                     **read_params)
    eval_features = data_reader.convert_examples_to_features(
        examples=eval_examples,
        tokenizer=tokenizer,
        max_seq_length=args.max_seq_length,
        doc_stride=args.doc_stride,
        max_query_length=args.max_query_length,
        is_training=False)

    eval_tensors = data_reader.data_to_tensors(eval_features)
    eval_data = TensorDataset(*eval_tensors)
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data,
                                 sampler=eval_sampler,
                                 batch_size=args.predict_batch_size)

    if args.do_train:

        if args.do_label:
            logger.info('Training in state-wise mode.')
            sentence_id_file_list = args.sentence_id_files
            if sentence_id_file_list is not None:
                for file in sentence_id_file_list:
                    train_features = data_reader.generate_features_sentence_ids(
                        train_features, file)
            else:
                train_features = data_reader.mask_all_sentence_ids(
                    train_features)
                logger.info('No sentence id supervision found.')
        else:
            logger.info('Training in the traditional way.')

        logger.info("Start training")
        train_loss = AverageMeter()
        best_acc = 0.0
        summary_writer = SummaryWriter(log_dir=args.output_dir)
        global_step = 0
        eval_loss = AverageMeter()

        train_tensors = data_reader.data_to_tensors(train_features)
        train_data = TensorDataset(*train_tensors)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
            # Train
            model.train()
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                if n_gpu == 1:
                    batch = batch_to_device(
                        batch, device)  # multi-gpu does scattering itself
                inputs = data_reader.generate_inputs(
                    batch,
                    train_features,
                    do_label=args.do_label,
                    model_state=ModelState.Train)
                loss = model(**inputs)['loss']
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
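                    # clip the fp32 master gradients that amp maintains, not
                    # the raw fp16 model gradients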
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    scheduler.step()  # step the schedule after the optimizer update
                    model.zero_grad()
                    global_step += 1

                train_loss.update(loss.item(), args.train_batch_size)
                summary_writer.add_scalar('train_loss', train_loss.avg,
                                          global_step)
                summary_writer.add_scalar('lr',
                                          scheduler.get_lr()[0], global_step)

            # Evaluation
            model.eval()
            all_results = []
            logger.info("Start evaluating")
            for eval_step, batch in enumerate(
                    tqdm(eval_dataloader, desc="Evaluating")):
                if n_gpu == 1:
                    batch = batch_to_device(
                        batch, device)  # multi-gpu does scattering itself
                inputs = data_reader.generate_inputs(
                    batch,
                    eval_features,
                    do_label=args.do_label,
                    model_state=ModelState.Evaluate)
                with torch.no_grad():
                    output_dict = model(**inputs)
                    loss = output_dict['loss']
                    batch_choice_logits = output_dict['yesno_logits']
                    eval_loss.update(loss.item(), args.predict_batch_size)
                    summary_writer.add_scalar(
                        'eval_loss', eval_loss.avg,
                        epoch * len(eval_dataloader) + eval_step)
                example_indices = batch[-1]
                for i, example_index in enumerate(example_indices):
                    choice_logits = batch_choice_logits[i].detach().cpu().tolist()

                    eval_feature = eval_features[example_index.item()]
                    unique_id = int(eval_feature.unique_id)
                    all_results.append(
                        RawResultChoice(unique_id=unique_id,
                                        choice_logits=choice_logits))

            data_reader.write_predictions(eval_examples,
                                          eval_features,
                                          all_results,
                                          None,
                                          null_score_diff_threshold=0.0)
            yes_metric = data_reader.yesno_cate.f1_measure('yes', 'no')
            no_metric = data_reader.yesno_cate.f1_measure('no', 'yes')
            current_acc = yes_metric['accuracy']
            summary_writer.add_scalar('eval_yes_f1', yes_metric['f1'], epoch)
            summary_writer.add_scalar('eval_yes_recall', yes_metric['recall'],
                                      epoch)
            summary_writer.add_scalar('eval_yes_precision',
                                      yes_metric['precision'], epoch)
            summary_writer.add_scalar('eval_no_f1', no_metric['f1'], epoch)
            summary_writer.add_scalar('eval_no_recall', no_metric['recall'],
                                      epoch)
            summary_writer.add_scalar('eval_no_precision',
                                      no_metric['precision'], epoch)
            summary_writer.add_scalar('eval_yesno_acc', current_acc, epoch)
            torch.cuda.empty_cache()

            if current_acc > best_acc:
                best_acc = current_acc
                model_to_save = model.module if hasattr(
                    model, 'module') else model  # Only save the model itself
                output_model_file = os.path.join(args.output_dir,
                                                 "pytorch_model.bin")
                torch.save(model_to_save.state_dict(), output_model_file)
            logger.info('Epoch: %d, Accuracy: %f (Best Accuracy: %f)' %
                        (epoch, current_acc, best_acc))
            data_reader.yesno_cate.reset()

        summary_writer.close()

    # Loading trained model.
    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
    model_state_dict = torch.load(output_model_file, map_location=device)
    model = initialize_model(args.bert_name,
                             args.model_file,
                             state_dict=model_state_dict,
                             **model_params)
    model.to(device)

    # Write Yes/No predictions
    if args.do_predict and (args.local_rank == -1
                            or torch.distributed.get_rank() == 0):

        test_examples = eval_examples
        test_features = eval_features

        test_tensors = data_reader.data_to_tensors(test_features)
        test_data = TensorDataset(*test_tensors)
        test_sampler = SequentialSampler(test_data)
        test_dataloader = DataLoader(test_data,
                                     sampler=test_sampler,
                                     batch_size=args.predict_batch_size)

        logger.info("***** Running predictions *****")
        logger.info("  Num orig examples = %d", len(test_examples))
        logger.info("  Num split examples = %d", len(test_features))
        logger.info("  Batch size = %d", args.predict_batch_size)

        model.eval()
        all_results = []
        logger.info("Start predicting yes/no on Dev set.")
        for batch in tqdm(test_dataloader, desc="Testing"):
            if n_gpu == 1:
                batch = batch_to_device(
                    batch, device)  # multi-gpu does scattering itself
            inputs = data_reader.generate_inputs(batch,
                                                 test_features,
                                                 do_label=args.do_label,
                                                 model_state=ModelState.Test)
            with torch.no_grad():
                batch_choice_logits = model(**inputs)['yesno_logits']
            example_indices = batch[-1]
            for i, example_index in enumerate(example_indices):
                choice_logits = batch_choice_logits[i].detach().cpu().tolist()

                test_feature = test_features[example_index.item()]
                unique_id = int(test_feature.unique_id)

                all_results.append(
                    RawResultChoice(unique_id=unique_id,
                                    choice_logits=choice_logits))

        output_prediction_file = os.path.join(args.predict_dir,
                                              'predictions.json')
        data_reader.write_predictions(eval_examples,
                                      eval_features,
                                      all_results,
                                      output_prediction_file,
                                      null_score_diff_threshold=0.0)
        yes_metric = data_reader.yesno_cate.f1_measure('yes', 'no')
        no_metric = data_reader.yesno_cate.f1_measure('no', 'yes')
        logger.info('Yes Metrics: %s' % json.dumps(yes_metric, indent=2))
        logger.info('No Metrics: %s' % json.dumps(no_metric, indent=2))

    # Labeling sentence id.
    if args.do_label and (args.local_rank == -1
                          or torch.distributed.get_rank() == 0):

        test_examples = train_examples
        test_features = train_features

        test_tensors = data_reader.data_to_tensors(test_features)
        test_data = TensorDataset(*test_tensors)
        test_sampler = SequentialSampler(test_data)
        test_dataloader = DataLoader(test_data,
                                     sampler=test_sampler,
                                     batch_size=args.predict_batch_size)

        logger.info("***** Running labeling *****")
        logger.info("  Num orig examples = %d", len(test_examples))
        logger.info("  Num split examples = %d", len(test_features))
        logger.info("  Batch size = %d", args.predict_batch_size)

        model.eval()
        all_results = []
        logger.info("Start labeling.")
        for batch in tqdm(test_dataloader, desc="Testing"):
            if n_gpu == 1:
                batch = batch_to_device(batch, device)
            inputs = data_reader.generate_inputs(batch,
                                                 test_features,
                                                 do_label=args.do_label,
                                                 model_state=ModelState.Test)
            with torch.no_grad():
                output_dict = model(**inputs)
                batch_choice_logits = output_dict['yesno_logits']
                batch_max_weight_indexes = output_dict['max_weight_index']
                batch_max_weight = output_dict['max_weight']
            example_indices = batch[-1]
            for i, example_index in enumerate(example_indices):
                choice_logits = batch_choice_logits[i].detach().cpu().tolist()
                max_weight_index = batch_max_weight_indexes[i].detach().cpu().tolist()
                max_weight = batch_max_weight[i].detach().cpu().tolist()

                test_feature = test_features[example_index.item()]
                unique_id = int(test_feature.unique_id)

                all_results.append(
                    WeightResultChoice(unique_id=unique_id,
                                       choice_logits=choice_logits,
                                       max_weight_index=max_weight_index,
                                       max_weight=max_weight))

        output_prediction_file = os.path.join(args.predict_dir,
                                              'sentence_id_file.json')
        data_reader.predict_sentence_ids(
            test_examples,
            test_features,
            all_results,
            output_prediction_file,
            weight_threshold=args.weight_threshold,
            only_correct=args.only_correct,
            label_threshold=args.label_threshold)
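
The function above leans on two idioms: scaling the loss down by gradient_accumulation_steps and only stepping the optimizer every N micro-batches. Below is a minimal, self-contained sketch of that pattern with a toy model and synthetic data (the names here are illustrative, not taken from the script above):

import torch
import torch.nn as nn

model = nn.Linear(16, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)
accumulation_steps = 4
max_grad_norm = 1.0

for step in range(32):
    inputs = torch.randn(8, 16)
    targets = torch.randint(0, 2, (8,))
    loss = nn.functional.cross_entropy(model(inputs), targets)
    # scale so the accumulated gradient matches one large-batch step
    (loss / accumulation_steps).backward()
    if (step + 1) % accumulation_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()
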
def main():
    parser = ArgumentParser()
    parser.add_argument('--pregenerated_data', type=Path, required=True)
    parser.add_argument('--output_dir', type=str, required=True)
    parser.add_argument("--bert_model", type=str, required=True, help="Bert pre-trained model selected in the list: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
    parser.add_argument("--do_lower_case", action="store_true")
    parser.add_argument("--reduce_memory", action="store_true",
                        help="Store training data as on-disc memmaps to massively reduce memory usage")

    parser.add_argument("--epochs", type=int, default=3, help="Number of epochs to train for")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=4,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument('--fp16',
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale',
                        type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                        "0 (default value): dynamic loss scaling.\n"
                        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument("--warmup_steps", 
                        default=0, 
                        type=int,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument("--adam_epsilon", 
                        default=1e-8, 
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--learning_rate",
                        default=3e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--fp16_opt_level', type=str, default='O1',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                             "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument("--max_grad_norm", default=1.0, type=float,
                        help="Max gradient norm.")
    
    args = parser.parse_args()

    assert args.pregenerated_data.is_dir(), \
        "--pregenerated_data should point to the folder of files made by pregenerate_training_data.py!"

    samples_per_epoch = []
    for i in range(args.epochs):
        epoch_file = args.pregenerated_data / f"epoch_{i}.json"
        metrics_file = args.pregenerated_data / f"epoch_{i}_metrics.json"
        if epoch_file.is_file() and metrics_file.is_file():
            metrics = json.loads(metrics_file.read_text())
            samples_per_epoch.append(metrics['num_training_examples'])
        else:
            if i == 0:
                exit("No training data was found!")
            print(f"Warning! There are fewer epochs of pregenerated data ({i}) than training epochs ({args.epochs}).")
            print("This script will loop over the available data, but training diversity may be negatively impacted.")
            num_data_epochs = i
            break
    else:
        num_data_epochs = args.epochs

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logging.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                            args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
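    # after this division train_batch_size is the per-step micro-batch; gradient
    # accumulation restores the originally requested effective batch size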

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # Create output directory if needed
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)

    total_train_examples = 0
    for i in range(args.epochs):
        # The modulo takes into account the fact that we may loop over limited epochs of data
        total_train_examples += samples_per_epoch[i % len(samples_per_epoch)]

    num_train_optimization_steps = int(
        total_train_examples / args.train_batch_size / args.gradient_accumulation_steps)
    if args.local_rank != -1:
        num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
        

    # Prepare model
    model = BertForPreTraining.from_pretrained(args.bert_model)
    model.to(device)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=num_train_optimization_steps)
    model.to(device)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)
    
    global_step = 0
    logging.info("***** Running training *****")
    logging.info(f"  Num examples = {total_train_examples}")
    logging.info("  Batch size = %d", args.train_batch_size)
    logging.info("  Num steps = %d", num_train_optimization_steps)
    for epoch in range(args.epochs):
        model.train()
        epoch_dataset = PregeneratedDataset(epoch=epoch, training_path=args.pregenerated_data, tokenizer=tokenizer,
                                            num_data_epochs=num_data_epochs, reduce_memory=args.reduce_memory)
        if args.local_rank == -1:
            train_sampler = RandomSampler(epoch_dataset)
        else:
            train_sampler = DistributedSampler(epoch_dataset)
        train_dataloader = DataLoader(epoch_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
        tr_loss = 0
        nb_tr_examples, nb_tr_steps = 0, 0
        with tqdm(total=len(train_dataloader), desc=f"Epoch {epoch}") as pbar:
            for step, batch in enumerate(train_dataloader):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, lm_label_ids, is_next = batch
                outputs = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next)
                loss = outputs[0]
                
                if n_gpu > 1:
                    loss = loss.mean() # mean() to average on multi-gpu parallel training
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                    
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    scheduler.step()  # update the LR schedule after the optimizer step
                    model.zero_grad()
                    global_step += 1

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                pbar.update(1)
                mean_loss = tr_loss * args.gradient_accumulation_steps / nb_tr_steps
                pbar.set_postfix_str(f"Loss: {mean_loss:.5f}")

            logging.info("** ** * Saving fine-tuned model ** ** * ")
            output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
            if not os.path.exists(output_dir):
                os.makedirs(output_dir)
            model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself

            # If we save using the predefined names, we can load using `from_pretrained`
            output_model_file = os.path.join(output_dir, WEIGHTS_NAME)
            output_config_file = os.path.join(output_dir, CONFIG_NAME)

            torch.save(model_to_save.state_dict(), output_model_file)
            model_to_save.config.to_json_file(output_config_file)
            tokenizer.save_vocabulary(output_dir)
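
Both scripts above pair AdamW with WarmupLinearSchedule from pytorch-transformers. The schedule itself is simple enough to reconstruct with a plain LambdaLR; the sketch below assumes a linear ramp from 0 to the base LR over warmup_steps followed by linear decay to 0 at t_total (a reconstruction of the behavior, not the library's exact code):

import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)
warmup_steps, t_total = 100, 1000

def lr_lambda(step):
    # linear ramp to 1.0, then linear decay to 0.0 at t_total
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (t_total - step) / max(1, t_total - warmup_steps))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

for step in range(t_total):
    # ... forward/backward would go here ...
    optimizer.step()    # weight update first,
    scheduler.step()    # then advance the schedule
    optimizer.zero_grad()
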
Example #15
def benchmark_training(model, opts):
    """Benchmarks training phase.

    :param obj model: A model to benchmark
    :param dict opts: A dictionary of parameters.
    :rtype: tuple:
    :return: A tuple of (model_name, list of batch times)
    """
    def _reduce_tensor(tensor):
        reduced = tensor.clone()
        dist.all_reduce(reduced, op=dist.ReduceOp.SUM)
        reduced /= opts['world_size']
        return reduced

    if opts['phase'] != 'training':
        raise "Phase in benchmark_training func is '%s'" % opts['phase']

    opts['distributed'] = opts['world_size'] > 1
    opts['with_cuda'] = opts['device'] == 'gpu'
    opts['fp16'] = opts['dtype'] == 'float16'
    opts['loss_scale'] = 1

    if opts['fp16'] and not opts['with_cuda']:
        raise ValueError(
            "Configuration error: FP16 can only be used with GPUs")

    if opts['with_cuda']:
        torch.cuda.set_device(opts['local_rank'])
        cudnn.benchmark = opts['cudnn_benchmark']
        cudnn.fastest = opts['cudnn_fastest']

    if opts['distributed']:
        dist.init_process_group(backend=opts['dist_backend'],
                                init_method='env://')

    if opts['with_cuda']:
        model = model.cuda()
        if opts['dtype'] == 'float16':
            model = network_to_half(model)

    if opts['distributed']:
        model = DDP(model, shared_param=True)

    if opts['fp16']:
        model_params, master_params = prep_param_lists(model)
    else:
        master_params = list(model.parameters())

    criterion = nn.CrossEntropyLoss()
    if opts['with_cuda']:
        criterion = criterion.cuda()
    optimizer = optim.SGD(master_params,
                          lr=0.01,
                          momentum=0.9,
                          weight_decay=1e-4)

    data_loader = DatasetFactory.get_data_loader(opts, opts['__input_shape'],
                                                 opts['__num_classes'])

    is_warmup = opts['num_warmup_batches'] > 0
    done = opts['num_warmup_batches'] == 0
    num_iterations_done = 0
    model.train()
    batch_times = np.zeros(opts['num_batches'])
    end_time = timeit.default_timer()
    while not done:
        prefetcher = DataPrefetcher(data_loader, opts)
        batch_data, batch_labels = prefetcher.next()
        while batch_data is not None:
            # torch.autograd.Variable is a no-op since PyTorch 0.4;
            # the tensors can be fed to the model directly
            output = model(batch_data)

            loss = criterion(output, batch_labels)
            loss = loss * opts['loss_scale']
            # I'll need this for reporting
            #reduced_loss = _reduce_tensor(loss.data) if opts['distributed'] else loss.data

            if opts['fp16']:
                model.zero_grad()
                loss.backward()
                model_grads_to_master_grads(model_params, master_params)
                if opts['loss_scale'] != 1:
                    for param in master_params:
                        param.grad.data = param.grad.data / opts['loss_scale']
                optimizer.step()
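                # copy the updated fp32 master weights back into the fp16 model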
                master_params_to_model_params(model_params, master_params)
            else:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            if opts['with_cuda']:
                torch.cuda.synchronize()

            # Track progress
            num_iterations_done += 1
            cur_time = timeit.default_timer()

            batch_data, batch_labels = prefetcher.next()

            if is_warmup:
                if num_iterations_done >= opts['num_warmup_batches']:
                    is_warmup = False
                    num_iterations_done = 0
            else:
                if opts['num_batches'] != 0:
                    batch_times[num_iterations_done - 1] = cur_time - end_time
                if num_iterations_done >= opts['num_batches']:
                    done = True
                    break
            end_time = cur_time

    return (opts['__name'], batch_times)
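
The benchmark above measures per-batch wall time only after a warmup phase and synchronizes CUDA before reading the clock. Stripped of the fp16 and distributed branches, the timing skeleton looks roughly like this (toy model and synthetic batches, for illustration only):

import timeit
import numpy as np
import torch

model = torch.nn.Linear(32, 10)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
num_warmup, num_batches = 5, 20
batch_times = np.zeros(num_batches)

for i in range(num_warmup + num_batches):
    start = timeit.default_timer()
    data = torch.randn(64, 32)
    labels = torch.randint(0, 10, (64,))
    optimizer.zero_grad()
    loss = criterion(model(data), labels)
    loss.backward()
    optimizer.step()
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for queued GPU work before timing
    if i >= num_warmup:           # discard the warmup iterations
        batch_times[i - num_warmup] = timeit.default_timer() - start

print("mean batch time: %.6f s" % batch_times.mean())
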
Example #16
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--training_data_path",
                        default=None,
                        type=str,
                        required=True,
                        help="The training data path")
    parser.add_argument("--validation_data_path",
                        default=None,
                        type=str,
                        required=True,
                        help="The validation data path")

    parser.add_argument(
        "--mcq_model",
        default=None,
        type=str,
        required=True,
        help="choose one from the list: bert-mcq-parallel-max, "
        "bert-mcq_parallel-weighted-sum, bert-mcq-concat, mac-bert, or add roberta instead of bert"
    )

    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
        "bert-base-multilingual-cased, bert-base-chinese, roberta-base, roberta-large"
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model checkpoints will be written."
    )

    ## Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=2e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument("--max_grad_norm",
                        default=None,
                        type=float,
                        help="Max gradient norm.")
    parser.add_argument(
        '--fp16_opt_level',
        type=str,
        default='O1',
        help=
        "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument("--weight_decay",
                        default=0.0,
                        type=float,
                        help="Weight deay if we apply some.")
    parser.add_argument("--adam_epsilon",
                        default=1e-8,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--warmup_steps",
                        default=0,
                        type=int,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument("--dropout", default=0.0, type=float, help="dropout")
    parser.add_argument(
        "--eval_freq",
        default=0,
        type=int,
        help="Evaluation steps frequency. Default is at the end of each epoch. "
        "You can also increase the frequency")
    parser.add_argument(
        '--tie_weights_weighted_sum',
        action='store_true',
        help="Whether to tie the weights for the weighted sum model")
    parser.add_argument('--max_number_premises',
                        type=int,
                        default=None,
                        help="Number of premise sentences to use at max")
    parser.add_argument('--num_labels',
                        type=int,
                        default=3,
                        help="Number of labels")
    parser.add_argument('--overwrite_output_dir',
                        action='store_true',
                        help="Overwrite the content of the output directory")
    parser.add_argument('--with_score',
                        action='store_true',
                        help="Knowledge with score is provided")

    args = parser.parse_args()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    # true (per-step) batch size; gradient_accumulation_steps is validated above
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if os.path.exists(args.output_dir) and os.listdir(
            args.output_dir) and not args.overwrite_output_dir:
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome."
            .format(args.output_dir))

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    with open(os.path.join(args.output_dir, "mcq_inputs.json"), 'w') as f:
        json.dump(vars(args), f, indent=2)

    stdout_handler = prepare_global_logging(args.output_dir, False)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if "roberta" in args.bert_model:
        tokenizer = RobertaTokenizer.from_pretrained(
            "roberta-large", do_lower_case=args.do_lower_case)
        logger.info("Type of Tokenizer : ROBERTA")
    else:
        tokenizer = BertTokenizer.from_pretrained(
            args.bert_model, do_lower_case=args.do_lower_case)
        logger.info("Type of Tokenizer : BERT")

    data_reader = None
    if args.mcq_model == 'bert-mcq-parallel-max':
        model = BertMCQParallel.from_pretrained(
            args.bert_model,
            cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE),
                                   'distributed_{}'.format(args.local_rank)))
        data_reader = BertMCQParallelReader()
    elif args.mcq_model == 'bert-mcq-concat':
        model = BertMCQConcat.from_pretrained(
            args.bert_model,
            cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE),
                                   'distributed_{}'.format(args.local_rank)))
        data_reader = BertMCQConcatReader()
    elif args.mcq_model == 'bert-mcq-weighted-sum':
        model = BertMCQWeightedSum.from_pretrained(
            args.bert_model,
            tie_weights=args.tie_weights_weighted_sum,
            cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE),
                                   'distributed_{}'.format(args.local_rank)))
        data_reader = BertMCQParallelReader()
    elif args.mcq_model == 'bert-mcq-simple-sum':
        model = BertMCQSimpleSum.from_pretrained(
            args.bert_model,
            cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE),
                                   'distributed_{}'.format(args.local_rank)))
        data_reader = BertMCQParallelReader()
    elif args.mcq_model == 'bert-mcq-mac':
        model = BertMCQMAC.from_pretrained(
            args.bert_model,
            cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE),
                                   'distributed_{}'.format(args.local_rank)))
        data_reader = BertMCQParallelReader()
    elif args.mcq_model == 'roberta-mcq-parallel-max':
        model = RoBertaMCQParallel.from_pretrained(
            args.bert_model,
            cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE),
                                   'distributed_{}'.format(args.local_rank)))
        data_reader = RoBertaMCQParallelReader()
    elif args.mcq_model == 'roberta-mcq-concat':
        model = RoBertaMCQConcat.from_pretrained(
            args.bert_model,
            cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE),
                                   'distributed_{}'.format(args.local_rank)))
        data_reader = RoBertaMCQConcatReader()

    elif args.mcq_model == 'roberta-mcq-weighted-sum':
        model = RoBertaMCQWeightedSum.from_pretrained(
            args.bert_model,
            tie_weights=args.tie_weights_weighted_sum,
            cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE),
                                   'distributed_{}'.format(args.local_rank)))
        data_reader = RoBertaMCQParallelReader()

    elif args.mcq_model == 'roberta-mcq-ws-score':
        model = RoBertaMCQWeightedSumScore.from_pretrained(
            args.bert_model,
            tie_weights=args.tie_weights_weighted_sum,
            cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE),
                                   'distributed_{}'.format(args.local_rank)))
        data_reader = RoBertaMCQParallelScoreReader()

    elif args.mcq_model == 'roberta-mcq-simple-sum':
        model = RoBertaMCQSimpleSum.from_pretrained(
            args.bert_model,
            cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE),
                                   'distributed_{}'.format(args.local_rank)))
        data_reader = RoBertaMCQParallelReader()

    elif args.mcq_model == 'roberta-mcq-ss-score':
        model = RoBertaMCQSimpleSumScore.from_pretrained(
            args.bert_model,
            cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE),
                                   'distributed_{}'.format(args.local_rank)))
        data_reader = RoBertaMCQParallelScoreReader()

    elif args.mcq_model == 'roberta-mcq-mac':
        model = RoBertaMCQMAC.from_pretrained(
            args.bert_model,
            cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE),
                                   'distributed_{}'.format(args.local_rank)))
        data_reader = RoBertaMCQParallelReader()
    elif args.mcq_model == 'roberta-mcq-conv3d':
        model = RoBertaMCQConv3d.from_pretrained(
            args.bert_model,
            cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE),
                                   'distributed_{}'.format(args.local_rank)))
        data_reader = RoBertaMCQParallelReader()
    else:
        logger.error(f"Invalid MCQ model name {args.mcq_model}")
        exit(1)

    if args.do_train:
        # Prepare data loader
        # get data loader for train/dev
        train_data = data_reader.read(args.training_data_path, tokenizer,
                                      args.max_seq_length,
                                      args.max_number_premises)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        eval_data = data_reader.read(args.validation_data_path, tokenizer,
                                     args.max_seq_length,
                                     args.max_number_premises)
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

        # total number of optimizer updates, dividing by the effective batch size
        t_total = (len(train_dataloader) //
                   args.gradient_accumulation_steps) * args.num_train_epochs

        num_train_optimization_steps = t_total
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()

        # Prepare optimizer
        # Prepare optimizer and schedule (linear warmup and decay)
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            args.weight_decay
        }, {
            'params': [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.0
        }]
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          eps=args.adam_epsilon)
        scheduler = WarmupLinearSchedule(optimizer,
                                         warmup_steps=args.warmup_steps,
                                         t_total=t_total)
        model.to(device)
        if args.fp16:
            try:
                from apex import amp
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
                )
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level=args.fp16_opt_level)

        if args.local_rank != -1:
            try:
                from apex.parallel import DistributedDataParallel as DDP
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
                )

            model = DDP(model)
        elif n_gpu > 1 and not args.no_cuda:
            model = torch.nn.DataParallel(model)

        global_step = 0
        number_of_batches_per_epoch = len(train_dataloader)
        if args.eval_freq > 0:
            steps_to_eval = args.eval_freq

        else:
            steps_to_eval = number_of_batches_per_epoch

        logger.info("***** Training *****")
        logger.info("  num examples = %d", len(train_data))
        logger.info("  batch size = %d", args.train_batch_size)
        logger.info("  num steps = %d", num_train_optimization_steps)
        logger.info("  number of Gpus= %d", n_gpu)
        logger.info("***** Evaluation *****")
        logger.info("  num examples = %d", len(eval_data))
        logger.info("  batch size = %d", args.eval_batch_size)

        best_acc = 0.0
        best_epoch = 1

        for epoch_index in trange(int(args.num_train_epochs), desc="Epoch"):
            epoch_start_time = time.time()
            model.train()
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            tq = tqdm(train_dataloader, desc="Iteration")
            acc = 0
            for step, batch in enumerate(tq):
                batch = tuple(t.to(device) for t in batch)
                if not args.with_score:
                    input_ids, segment_ids, input_mask, label_ids = batch
                    outputs = model(input_ids, segment_ids, input_mask,
                                    label_ids)
                else:
                    input_ids, segment_ids, input_mask, scores, label_ids = batch
                    outputs = model(input_ids, segment_ids, input_mask, scores,
                                    label_ids)
                loss = outputs[0]
                logits = outputs[1]
                logits = logits.detach().cpu().numpy()
                label_ids = label_ids.to('cpu').numpy()
                tmp_accuracy = accuracy(logits, label_ids)
                acc += tmp_accuracy

                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu parallel training
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                    if args.max_grad_norm is not None:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer), args.max_grad_norm)
                else:
                    loss.backward()
                    if args.max_grad_norm is not None:
                        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       args.max_grad_norm)

                if (step + 1) % args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    scheduler.step()  # update the LR schedule after the optimizer step
                    model.zero_grad()
                    global_step += 1

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                tq.set_description(
                    _get_loss_accuracy(tr_loss / nb_tr_steps,
                                       acc / nb_tr_examples))

                # TODO: always eval on last batch
                # For now select the batch_size appropriately
                if (((step + 1) % steps_to_eval == 0
                        or (step + 1) == number_of_batches_per_epoch)
                        and args.do_eval
                        and (args.local_rank == -1
                             or torch.distributed.get_rank() == 0)):
                    model.eval()
                    eval_loss, eval_accuracy = 0, 0
                    nb_eval_steps, nb_eval_examples = 0, 0
                    etq = tqdm(eval_dataloader, desc="Validating")
                    for batch in etq:
                        batch = tuple(t.to(device) for t in batch)

                        with torch.no_grad():
                            if not args.with_score:
                                input_ids, segment_ids, input_mask, label_ids = batch
                                outputs = model(input_ids, segment_ids,
                                                input_mask, label_ids)
                            else:
                                input_ids, segment_ids, input_mask, scores, label_ids = batch
                                outputs = model(input_ids, segment_ids,
                                                input_mask, scores, label_ids)

                            tmp_eval_loss = outputs[0]
                            logits = outputs[1]

                        logits = logits.detach().cpu().numpy()
                        label_ids = label_ids.to('cpu').numpy()
                        tmp_eval_accuracy = accuracy(logits, label_ids)

                        eval_loss += tmp_eval_loss.mean().item()
                        eval_accuracy += tmp_eval_accuracy

                        nb_eval_examples += input_ids.size(0)
                        nb_eval_steps += 1

                        etq.set_description(
                            _get_loss_accuracy(
                                eval_loss / nb_eval_steps,
                                eval_accuracy / nb_eval_examples))

                    eval_loss = eval_loss / nb_eval_steps
                    eval_accuracy = eval_accuracy / nb_eval_examples

                    logger.info(f"epoch, step | {epoch_index}, {step}")
                    logger.info("            |   Training |  Validation")
                    logger.info("accuracy    |   %.4f" %
                                (acc / nb_tr_examples) +
                                "  |   %.4f" % eval_accuracy)
                    logger.info("loss        |   %.4f" %
                                (tr_loss / nb_tr_steps) +
                                "  |   %.4f" % eval_loss)
                    best_acc = max(best_acc, eval_accuracy)

                    if eval_accuracy == best_acc:
                        best_epoch = (epoch_index, step)
                        logger.info(
                            "best validation performance so far %.4f: " %
                            best_acc + ", best epoch: " + str(best_epoch) +
                            ". saving current model to " + args.output_dir)

                        # Save a trained model, configuration and tokenizer
                        model_to_save = model.module if hasattr(
                            model, 'module') else model  # Only save the model itself

                        # If we save using the predefined names, we can load using `from_pretrained`
                        output_model_file = os.path.join(
                            args.output_dir, WEIGHTS_NAME)
                        output_config_file = os.path.join(
                            args.output_dir, CONFIG_NAME)
                        torch.save(model_to_save.state_dict(),
                                   output_model_file)
                        model_to_save.config.to_json_file(output_config_file)
                        tokenizer.save_vocabulary(args.output_dir)
                model.train()

            epoch_end_time = time.time()
            logger.info(
                f"time it took to finish the epoch {epoch_index} of {args.num_train_epochs} is "
                + _show_runtime(epoch_end_time - epoch_start_time))

        # Does this even make sense to output?
        result = {
            'eval_accuracy': best_acc,
            'global_step': global_step,
            'best_epoch': best_epoch
        }
        cleanup_global_logging(stdout_handler)
        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
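
A recurring idiom in these scripts is the two-group optimizer setup that exempts biases and LayerNorm parameters from weight decay. A minimal sketch with a toy module (hypothetical names, chosen so the substring match behaves the way it does for BERT's parameter names):

import torch

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.dense = torch.nn.Linear(8, 8)
        self.LayerNorm = torch.nn.LayerNorm(8)

model = Toy()
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
grouped = [
    # decay everything except biases and LayerNorm parameters
    {'params': [p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)],
     'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)],
     'weight_decay': 0.0},
]
optimizer = torch.optim.AdamW(grouped, lr=2e-5, eps=1e-8)
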
Example #17
def main():
    #-----------------------------------------------------------------------------------
    sys.stdout = open('taco_k_log.txt', 'a')
    #-----------------------------------------------------------------------------------

    parser = argparse.ArgumentParser(description='PyTorch Tacotron 2 Training')
    parser = parse_args(parser)
    args, _ = parser.parse_known_args()

    if 'LOCAL_RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        local_rank = int(os.environ['LOCAL_RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
    else:
        local_rank = args.rank
        world_size = args.world_size

    distributed_run = world_size > 1

    if local_rank == 0:
        DLLogger.init(backends=[
            JSONStreamBackend(Verbosity.DEFAULT, args.output + '/' +
                              args.log_file),
            StdOutBackend(Verbosity.VERBOSE)
        ])
    else:
        DLLogger.init(backends=[])

    for k, v in vars(args).items():
        DLLogger.log(step="PARAMETER", data={k: v})
    DLLogger.log(step="PARAMETER", data={'model_name': 'Tacotron2_PyT'})

    model_name = args.model_name
    parser = models.parse_model_args(model_name, parser)
    args, _ = parser.parse_known_args()

    torch.backends.cudnn.enabled = args.cudnn_enabled
    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    if distributed_run:
        init_distributed(args, world_size, local_rank, args.group_name)

    torch.cuda.synchronize()
    run_start_time = time.perf_counter()

    model_config = models.get_model_config(model_name, args)
    model = models.get_model(
        model_name,
        model_config,
        to_cuda=True,
        uniform_initialize_bn_weight=not args.disable_uniform_initialize_bn_weight)

    if not args.amp_run and distributed_run:
        model = DDP(model)

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.learning_rate,
                                 weight_decay=args.weight_decay)

    if args.amp_run:
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
        if distributed_run:
            model = DDP(model)

    try:
        sigma = args.sigma
    except AttributeError:
        sigma = None

    start_epoch = [0]

    if args.checkpoint_path is not "":
        load_checkpoint(model, optimizer, start_epoch, model_config,
                        args.amp_run, args.checkpoint_path)

    start_epoch = start_epoch[0]

    criterion = loss_functions.get_loss_function(model_name, sigma)

    try:
        n_frames_per_step = args.n_frames_per_step
    except AttributeError:
        n_frames_per_step = None

    collate_fn = data_functions.get_collate_function(model_name,
                                                     n_frames_per_step)
    trainset = data_functions.get_data_loader(model_name, args.dataset_path,
                                              args.training_files, args)
    if distributed_run:
        train_sampler = DistributedSampler(trainset)
        shuffle = False
    else:
        train_sampler = None
        shuffle = True

    train_loader = DataLoader(trainset,
                              num_workers=1,
                              shuffle=shuffle,
                              sampler=train_sampler,
                              batch_size=args.batch_size,
                              pin_memory=False,
                              drop_last=True,
                              collate_fn=collate_fn)

    valset = data_functions.get_data_loader(model_name, args.dataset_path,
                                            args.validation_files, args)

    batch_to_gpu = data_functions.get_batch_to_gpu(model_name)

    iteration = 0
    train_epoch_items_per_sec = 0.0
    val_loss = 0.0
    num_iters = 0

    model.train()

    for epoch in range(start_epoch, args.epochs):
        torch.cuda.synchronize()
        epoch_start_time = time.perf_counter()
        # used to calculate avg items/sec over epoch
        reduced_num_items_epoch = 0

        # used to calculate avg loss over epoch
        train_epoch_avg_loss = 0.0
        train_epoch_items_per_sec = 0.0

        num_iters = 0

        # if overflow at the last iteration then do not save checkpoint
        overflow = False

        if distributed_run:
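            # DistributedSampler derives its shuffle from the epoch number,
            # so set_epoch gives each epoch a different permutation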
            train_loader.sampler.set_epoch(epoch)

        for i, batch in enumerate(train_loader):
            torch.cuda.synchronize()
            iter_start_time = time.perf_counter()
            DLLogger.log(step=(epoch, i),
                         data={
                             'glob_iter/iters_per_epoch':
                             str(iteration) + "/" + str(len(train_loader))
                         })

            adjust_learning_rate(iteration, epoch, optimizer,
                                 args.learning_rate, args.anneal_steps,
                                 args.anneal_factor, local_rank)

            model.zero_grad()
            x, y, num_items = batch_to_gpu(batch)

            y_pred = model(x)
            loss = criterion(y_pred, y)

            if distributed_run:
                reduced_loss = reduce_tensor(loss.data, world_size).item()
                reduced_num_items = reduce_tensor(num_items.data, 1).item()
            else:
                reduced_loss = loss.item()
                reduced_num_items = num_items.item()
            if np.isnan(reduced_loss):
                raise Exception("loss is NaN")

            DLLogger.log(step=(epoch, i), data={'train_loss': reduced_loss})

            train_epoch_avg_loss += reduced_loss
            num_iters += 1

            # accumulate number of items processed in this epoch
            reduced_num_items_epoch += reduced_num_items

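            # with amp, backward() runs on the scaled loss, and gradient clipping
            # must be applied to the fp32 master params that amp maintains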
            if args.amp_run:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    amp.master_params(optimizer), args.grad_clip_thresh)
            else:
                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), args.grad_clip_thresh)

            optimizer.step()

            torch.cuda.synchronize()
            iter_stop_time = time.perf_counter()
            iter_time = iter_stop_time - iter_start_time
            items_per_sec = reduced_num_items / iter_time
            train_epoch_items_per_sec += items_per_sec

            #DLLogger.log(step=(epoch, i), data={'train_items_per_sec': items_per_sec})
            #DLLogger.log(step=(epoch, i), data={'train_iter_time': iter_time})
            iteration += 1

        torch.cuda.synchronize()
        epoch_stop_time = time.perf_counter()
        epoch_time = epoch_stop_time - epoch_start_time

        #DLLogger.log(step=(epoch,), data={'train_items_per_sec':
        #                                  (train_epoch_items_per_sec/num_iters if num_iters > 0 else 0.0)})
        DLLogger.log(step=(epoch, ),
                     data={
                         'train_loss': (train_epoch_avg_loss /
                                        num_iters if num_iters > 0 else 0.0)
                     })
        #DLLogger.log(step=(epoch,), data={'train_epoch_time': epoch_time})

        val_loss = validate(model, criterion, valset, epoch, i,
                            args.batch_size, world_size, collate_fn,
                            distributed_run, local_rank, batch_to_gpu)

        if (epoch % args.epochs_per_checkpoint == 0 and local_rank == 0
                and args.bench_class == ""):
            checkpoint_path = os.path.join(
                args.output, "checkpoint_{}_{}".format(model_name, epoch))
            save_checkpoint(model, optimizer, epoch, model_config,
                            args.amp_run, checkpoint_path)
        if local_rank == 0:
            DLLogger.flush()

    torch.cuda.synchronize()
    run_stop_time = time.perf_counter()
    run_time = run_stop_time - run_start_time
    #DLLogger.log(step=tuple(), data={'run_time': run_time})
    DLLogger.log(step=tuple(), data={'val_loss': val_loss})
    #DLLogger.log(step=tuple(), data={'train_items_per_sec':
    #                                 (train_epoch_items_per_sec/num_iters if num_iters > 0 else 0.0)})

    if local_rank == 0:
        DLLogger.flush()
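The distributed branch above depends on a reduce_tensor helper that is not shown in this listing; a minimal sketch, assuming a plain all-reduce sum divided by the given count (calling it with num_gpus=1, as done for num_items, keeps the raw sum):

import torch
import torch.distributed as dist

def reduce_tensor(tensor, num_gpus):
    # sum the tensor across all workers, then average over num_gpus;
    # with num_gpus=1 this is simply the global sum
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    return rt / num_gpus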
Example #18
def main():
    args = parse_args()

    # Devices
    if args.local_rank == -1:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        torch.distributed.init_process_group(backend="nccl")
    default_gpu = False
    if dist.is_available() and args.local_rank != -1:
        rank = dist.get_rank()
        if rank == 0:
            default_gpu = True
    else:
        default_gpu = True
    logger.info(f"device: {device} n_gpu: {n_gpu}, distributed training: {bool(args.local_rank != -1)}")

    # Load config
    config = BertConfig.from_json_file(args.config_file)

    # Load task config
    with open(args.tasks_config_file, "r") as f:
        task_cfg = edict(yaml.safe_load(f))
    task_id = args.task.strip()
    task = "TASK" + task_id
    task_name = task_cfg[task]["name"]
    base_lr = task_cfg[task]["lr"]
    if task_cfg[task].get("fusion_method", None):
        # VL-BERT pooling for VQA
        config.fusion_method = task_cfg[task]["fusion_method"]

    # Output dirs
    if args.save_name:
        prefix = "-" + args.save_name
    else:
        prefix = ""
    timestamp = (task_name + "_" + args.config_file.split("/")[1].split(".")[0] + prefix)
    save_path = os.path.join(args.output_dir, timestamp)
    if default_gpu:
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        # save all the hidden parameters.
        with open(os.path.join(save_path, "command.txt"), "w") as f:
            print(args, file=f)  # Python 3.x
            print("\n", file=f)
            print(config, file=f)

    # Seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Dataset
    batch_size, task2num_iters, dset_train, dset_val, dl_train, dl_val = LoadDataset(args, config, task_cfg, args.task)

    # Logging
    logdir = os.path.join(args.logdir, timestamp)
    tb_logger = tbLogger(logdir, save_path, [task_name], [task], task2num_iters, args.grad_acc_steps)
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    # Model
    if "roberta" in args.bert_model:
        config.model = "roberta"
    model = BertForVLTasks.from_pretrained(args.from_pretrained, config=config, task_cfg=task_cfg, task_ids=[task])
    if task_cfg[task].get("embed_clf", None):
        logger.info('Initializing classifier weight for %s from pretrained word embeddings...' % task)
        answers_word_embed = []
        for k, v in model.state_dict().items():
            if 'bert.embeddings.word_embeddings.weight' in k:
                word_embeddings = v.detach().clone()
                break
        for answer, label in sorted(dset_train.ans2label.items()):
            a_tokens = dset_train._tokenizer.tokenize(answer)
            a_ids = dset_train._tokenizer.convert_tokens_to_ids(a_tokens)
            if len(a_ids):
                a_word_embed = (torch.stack([word_embeddings[a_id] for a_id in a_ids], dim=0)).mean(dim=0)
            else:
                a_tokens = dset_train._tokenizer.tokenize("<unk>")
                a_id = dset_train._tokenizer.convert_tokens_to_ids(a_tokens)[0]
                a_word_embed = word_embeddings[a_id]
            answers_word_embed.append(a_word_embed)
        answers_word_embed_tensor = torch.stack(answers_word_embed, dim=0)
        for name, module in model.named_modules():
            if name.endswith('clfs_dict.%s.logit_fc.3' % task):
                module.weight.data = answers_word_embed_tensor.to(device=module.weight.data.device)

    # Optimization details
    freeze_layers(model)
    criterion = LoadLoss(task_cfg, args.task)
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = []
    for key, value in dict(model.named_parameters()).items():
        if value.requires_grad:
            if "vil_" in key:
                lr = 1e-4
            else:
                lr = base_lr
            if any(nd in key for nd in no_decay):
                optimizer_grouped_parameters += [{"params": [value], "lr": lr, "weight_decay": 0.0}]
            else:
                optimizer_grouped_parameters += [{"params": [value], "lr": lr, "weight_decay": args.weight_decay}]
    if default_gpu:
        print(len(list(model.named_parameters())), len(optimizer_grouped_parameters))
    if args.optim == "AdamW":
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=base_lr,
                          eps=args.adam_epsilon,
                          betas=args.adam_betas,
                          correct_bias=args.adam_correct_bias)
    elif args.optim == "RAdam":
        optimizer = RAdam(optimizer_grouped_parameters, lr=base_lr)
    num_train_optim_steps = (task2num_iters[task] * args.num_train_epochs // args.grad_acc_steps)
    warmup_steps = args.warmup_steps or args.warmup_proportion * num_train_optim_steps
    if args.lr_scheduler == "warmup_linear":
        scheduler = WarmupLinearSchedule(optimizer, warmup_steps=warmup_steps, t_total=num_train_optim_steps)
    else:
        scheduler = WarmupConstantSchedule(optimizer, warmup_steps=warmup_steps)

    # Resume training
    start_iter_id, global_step, start_epoch, tb_logger, max_score = \
        resume(args.resume_file, model, optimizer, scheduler, tb_logger)

    # Move to GPU(s)
    model.to(device)
    for state in optimizer.state.values():
        for k, v in state.items():
            if torch.is_tensor(v):
                state[k] = v.cuda()
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )
        model = DDP(model, delay_allreduce=True)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Save starting model
    save(save_path, logger, -1, model, optimizer, scheduler, global_step, tb_logger, default_gpu)

    # Print summary
    if default_gpu:
        summary_parameters(model, logger)
        print("***** Running training *****")
        print("  Num Iters: ", task2num_iters[task])
        print("  Batch size: ", batch_size)
        print("  Num steps: %d" % num_train_optim_steps)

    # Train
    for epoch_id in tqdm(range(start_epoch, args.num_train_epochs), desc="Epoch"):
        model.train()
        for step, batch in enumerate(dl_train):
            iter_id = start_iter_id + step + (epoch_id * len(dl_train))

            loss, score = ForwardModelsTrain(config, task_cfg, device, task, batch, model, criterion)

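            # scale the loss down so that gradients summed over grad_acc_steps
            # micro-batches match the gradient of a single large batch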
            if args.grad_acc_steps > 1:
                loss = loss / args.grad_acc_steps
            loss.backward()

            if (step + 1) % args.grad_acc_steps == 0:
                # Clip gradient
                if args.clip_grad_norm > 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)

                optimizer.step()

                if global_step < warmup_steps or args.lr_scheduler == "warmup_linear":
                    scheduler.step()

                model.zero_grad()
                global_step += 1

                if default_gpu:
                    tb_logger.step_train(epoch_id, iter_id, float(loss), float(score),
                                         optimizer.param_groups[0]["lr"], task, "train")

            if (step % (20 * args.grad_acc_steps) == 0) and step != 0 and default_gpu:
                tb_logger.showLossTrain()

            # Decide whether to evaluate task
            if iter_id != 0 and iter_id % task2num_iters[task] == 0:
                score = evaluate(config, dl_val, task_cfg, device, task, model, criterion, epoch_id, default_gpu, tb_logger)
                if score > max_score:
                    max_score = score
                    save(save_path, logger, epoch_id, model, optimizer, scheduler,
                         global_step, tb_logger, default_gpu, max_score)

        save(save_path, logger, epoch_id, model, optimizer, scheduler, global_step, tb_logger, default_gpu, max_score)

    tb_logger.txt_close()
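WarmupLinearSchedule and WarmupConstantSchedule above come from pytorch-transformers. As a rough sketch of the assumed behavior, the warmup-linear variant scales the base learning rate by a multiplier that ramps up over warmup_steps and then decays linearly to zero at t_total:

def warmup_linear_multiplier(step, warmup_steps, t_total):
    # linear ramp 0 -> 1 during warmup, then linear decay 1 -> 0 until t_total
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (t_total - step) / max(1.0, t_total - warmup_steps))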
Example #19
def main():

    parser = argparse.ArgumentParser(description='PyTorch Tacotron 2 Training')
    parser = parse_args(parser)
    args, _ = parser.parse_known_args()

    LOGGER.set_model_name("Tacotron2_PyT")
    LOGGER.set_backends([
        dllg.StdOutBackend(log_file=None,
                           logging_scope=dllg.TRAIN_ITER_SCOPE,
                           iteration_interval=1),
        dllg.JsonBackend(log_file=args.log_file if args.rank == 0 else None,
                         logging_scope=dllg.TRAIN_ITER_SCOPE,
                         iteration_interval=1)
    ])

    LOGGER.timed_block_start("run")
    LOGGER.register_metric(tags.TRAIN_ITERATION_LOSS,
                           metric_scope=dllg.TRAIN_ITER_SCOPE)
    LOGGER.register_metric("iter_time", metric_scope=dllg.TRAIN_ITER_SCOPE)
    LOGGER.register_metric("epoch_time", metric_scope=dllg.EPOCH_SCOPE)
    LOGGER.register_metric("run_time", metric_scope=dllg.RUN_SCOPE)
    LOGGER.register_metric("val_iter_loss", metric_scope=dllg.EPOCH_SCOPE)
    LOGGER.register_metric("train_epoch_items/sec",
                           metric_scope=dllg.EPOCH_SCOPE)
    LOGGER.register_metric("train_epoch_avg_loss",
                           metric_scope=dllg.EPOCH_SCOPE)

    log_hardware()

    model_name = args.model_name
    parser = models.parse_model_args(model_name, parser)
    args = parser.parse_args()

    log_args(args)

    torch.backends.cudnn.enabled = args.cudnn_enabled
    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    distributed_run = args.world_size > 1
    if distributed_run:
        init_distributed(args, args.world_size, args.rank, args.group_name)

    LOGGER.log(key=tags.RUN_START)
    run_start_time = time.time()

    model_config = models.get_model_config(model_name, args)
    model = models.get_model(model_name,
                             model_config,
                             to_fp16=args.fp16_run,
                             to_cuda=True)

    epoch_start = 0
    if args.resume:
        resume_model_path = args.resume_tacotron2_path if args.model_name == "Tacotron2" else args.resume_waveglow_path
        checkpoint = torch.load(resume_model_path, map_location='cpu')
        epoch_start = checkpoint["epoch"]
        state_dict = checkpoint['state_dict']
        if checkpoint_from_distributed(state_dict):
            state_dict = unwrap_distributed(state_dict)

        model.load_state_dict(state_dict)
        print("restore model %s" % resume_model_path)

    if distributed_run:
        model = DDP(model)

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.learning_rate,
                                 weight_decay=args.weight_decay)

    if args.fp16_run:
        optimizer = FP16_Optimizer(
            optimizer, dynamic_loss_scale=args.dynamic_loss_scaling)

    sigma = getattr(args, 'sigma', None)

    criterion = loss_functions.get_loss_function(model_name, sigma)

    n_frames_per_step = getattr(args, 'n_frames_per_step', None)

    collate_fn = data_functions.get_collate_function(model_name,
                                                     n_frames_per_step)
    trainset = data_functions.get_data_loader(model_name, args.dataset_path,
                                              args.training_files, args)
    train_sampler = DistributedSampler(trainset) if distributed_run else None
    train_loader = DataLoader(trainset,
                              num_workers=1,
                              shuffle=False,
                              sampler=train_sampler,
                              batch_size=args.batch_size,
                              pin_memory=False,
                              drop_last=True,
                              collate_fn=collate_fn)

    valset = data_functions.get_data_loader(model_name, args.dataset_path,
                                            args.validation_files, args)

    batch_to_gpu = data_functions.get_batch_to_gpu(model_name)

    iteration = 0
    model.train()

    LOGGER.log(key=tags.TRAIN_LOOP)

    for epoch in range(epoch_start, args.epochs):
        LOGGER.epoch_start()
        epoch_start_time = time.time()
        LOGGER.log(key=tags.TRAIN_EPOCH_START, value=epoch)

        # used to calculate avg items/sec over epoch
        reduced_num_items_epoch = 0

        # used to calculate avg loss over epoch
        train_epoch_avg_loss = 0.0
        num_iters = 0

        # if overflow at the last iteration then do not save checkpoint
        overflow = False

        for i, batch in enumerate(train_loader):
            LOGGER.iteration_start()
            iter_start_time = time.time()
            LOGGER.log(key=tags.TRAIN_ITER_START, value=i)
            print("Batch: {}/{} epoch {}".format(i, len(train_loader), epoch))

            adjust_learning_rate(epoch, optimizer, args.learning_rate,
                                 args.anneal_steps, args.anneal_factor)

            model.zero_grad()
            x, y, num_items = batch_to_gpu(batch)

            if args.fp16_run:
                y_pred = model(fp32_to_fp16(x))
                loss = criterion(fp16_to_fp32(y_pred), y)
            else:
                y_pred = model(x)
                loss = criterion(y_pred, y)

            if distributed_run:
                reduced_loss = reduce_tensor(loss.data, args.world_size).item()
                reduced_num_items = reduce_tensor(num_items.data, 1).item()
            else:
                reduced_loss = loss.item()
                reduced_num_items = num_items.item()
            if np.isnan(reduced_loss):
                raise Exception("loss is NaN")

            LOGGER.log(key=tags.TRAIN_ITERATION_LOSS, value=reduced_loss)

            train_epoch_avg_loss += reduced_loss
            num_iters += 1

            # accumulate number of items processed in this epoch
            reduced_num_items_epoch += reduced_num_items

            if args.fp16_run:
                optimizer.backward(loss)
                grad_norm = optimizer.clip_master_grads(args.grad_clip_thresh)
            else:
                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), args.grad_clip_thresh)

            optimizer.step()

            overflow = optimizer.overflow if args.fp16_run else False
            iteration += 1

            LOGGER.log(key=tags.TRAIN_ITER_STOP, value=i)

            iter_stop_time = time.time()
            iter_time = iter_stop_time - iter_start_time
            LOGGER.log(key="train_iter_items/sec",
                       value=(reduced_num_items / iter_time))
            LOGGER.log(key="iter_time", value=iter_time)
            LOGGER.iteration_stop()

        LOGGER.log(key=tags.TRAIN_EPOCH_STOP, value=epoch)
        epoch_stop_time = time.time()
        epoch_time = epoch_stop_time - epoch_start_time

        LOGGER.log(key="train_epoch_items/sec",
                   value=(reduced_num_items_epoch / epoch_time))
        LOGGER.log(key="train_epoch_avg_loss",
                   value=(train_epoch_avg_loss /
                          num_iters if num_iters > 0 else 0.0))
        LOGGER.log(key="epoch_time", value=epoch_time)

        LOGGER.log(key=tags.EVAL_START, value=epoch)

        validate(model, criterion, valset, iteration, args.batch_size,
                 args.world_size, collate_fn, distributed_run, args.rank,
                 batch_to_gpu, args.fp16_run)

        LOGGER.log(key=tags.EVAL_STOP, value=epoch)

        if (not overflow and epoch % args.epochs_per_checkpoint == 0
                and args.rank == 0):
            checkpoint_path = os.path.join(
                args.output_directory,
                "checkpoint_{}_{}".format(model_name, epoch))
            save_checkpoint(model, epoch, model_config, checkpoint_path)
            save_sample(
                model_name, model, args.waveglow_checkpoint,
                args.tacotron2_checkpoint, args.phrase_path,
                os.path.join(args.output_directory,
                             "sample_{}_{}.wav".format(model_name, iteration)),
                args.sampling_rate, args.fp16_run)

        LOGGER.epoch_stop()

    run_stop_time = time.time()
    run_time = run_stop_time - run_start_time
    LOGGER.log(key="run_time", value=run_time)
    LOGGER.log(key=tags.RUN_FINAL)

    print("training time", run_stop_time - run_start_time)

    LOGGER.timed_block_stop("run")

    if args.rank == 0:
        LOGGER.finish()
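checkpoint_from_distributed and unwrap_distributed (used in the resume path above) are helpers for loading a DDP checkpoint into an unwrapped model; a minimal sketch of what they are assumed to do, namely detect and strip the 'module.' prefix that DistributedDataParallel adds to parameter names:

def checkpoint_from_distributed(state_dict):
    # a checkpoint saved from a DDP-wrapped model has keys like 'module.encoder.weight'
    return any(key.startswith('module.') for key in state_dict)

def unwrap_distributed(state_dict):
    # drop the leading 'module.' so keys match the bare (unwrapped) model
    return {key[len('module.'):] if key.startswith('module.') else key: value
            for key, value in state_dict.items()}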
Example #20
class Trainer(object):
    def __init__(self,
                 device,
                 train_data,
                 valid_data,
                 dicts,
                 opt,
                 setup_optimizer=True):
        """
        :param model:
        :param device: int (GPU id)
        :param loss_function:
        :param train_data:
        :param valid_data:
        :param dicts:
        :param opt:
        """

        # self.model = model
        # self.loss_function = loss_function
        self.device = device
        opt.node_rank = 0
        opt.nodes = 1
        self.world_size = len(opt.gpus)

        # in the case of single node distributed, it should equal self.device
        self.rank = self.device

        # make a group to later use with dist.all_reduce
        self.group = dist.group.WORLD

        self.print("[INFO] Training Options:", opt)
        dist.init_process_group(backend='nccl',
                                init_method='env://',
                                world_size=self.world_size,
                                rank=self.rank)

        self.model = None

        if self.rank == 0:
            self.train_data = train_data
            self.valid_data = valid_data
        else:
            self.train_data = copy.deepcopy(train_data)
            self.valid_data = copy.deepcopy(valid_data)

        self.dicts = dicts
        self.opt = opt
        self.cuda = (len(opt.gpus) >= 1 and opt.gpus[0] >= 0)

        assert self.cuda, "[ERROR] Training is only available on GPUs."

        self.start_time = 0

        # setting up models and others
        # note: the LID auxiliary loss is attached below, once the loss function exists

        torch.manual_seed(self.opt.seed)

        # note: we must start creating models after creating the processes,
        # because passing a pre-created model to a process triggers a "pickle" error
        if not opt.fusion:

            if self.is_main():
                print("BUILDING MODEL .... ", flush=True)
            model = build_model(opt, dicts)
            """ Building the loss function """
            if opt.ctc_loss != 0:
                loss_function = NMTAndCTCLossFunc(
                    dicts['tgt'].size(),
                    label_smoothing=opt.label_smoothing,
                    ctc_weight=opt.ctc_loss)
            elif opt.nce:
                from onmt.modules.nce.nce_loss import NCELoss
                loss_function = NCELoss(opt.model_size,
                                        dicts['tgt'].size(),
                                        noise_ratio=opt.nce_noise,
                                        logz=9,
                                        label_smoothing=opt.label_smoothing)
            else:
                loss_function = NMTLossFunc(
                    opt.model_size,
                    dicts['tgt'].size(),
                    label_smoothing=opt.label_smoothing,
                    mirror=opt.mirror_loss,
                    fast_xentropy=opt.fast_xentropy)

            # This function replaces modules with more optimized counterparts so the model runs faster
            # Currently exp with LayerNorm
            if not opt.memory_profiling:
                # distributed is required to convert BatchNorm to SyncBatchNorm for DDP
                optimize_model(model, distributed=(self.world_size > 1))
                # optimize_model(model)

        init_model_parameters(model, opt)
        self.model = model
        self.loss_function = loss_function

        if opt.lfv_multilingual:
            from onmt.models.speech_recognizer.lid_loss import CrossEntropyLIDLoss
            lid_loss = CrossEntropyLIDLoss(opt.n_languages,
                                           opt.label_smoothing,
                                           opt.fast_xentropy)
            self.loss_function.add_loss_function(lid_loss, 'lid_loss')

        if self.cuda:
            torch.cuda.set_device(self.device)

            self.loss_function = self.loss_function.cuda(device=self.device)
            self.model = self.model.cuda(device=self.device)

            # Ensure that the distributed copies have the same initial parameters
            # Manual seed may not work the same for different GPU models.
            if self.world_size > 1:
                params = [p for p in self.model.parameters()]

                with torch.no_grad():
                    if not self.is_main():
                        for p in params:
                            p.zero_()
                    else:
                        for p in params:
                            p.add_(0)

            if self.world_size > 1:
                params = [p for p in self.model.parameters()]
                all_reduce_and_rescale_tensors(params, 1)

        if setup_optimizer:

            self.optim = onmt.Optim(opt)
            self.optim.set_parameters(self.model.parameters())

            if self.is_main():
                print("[INFO] Optimizer: ", self.optim.optimizer)

            if not self.opt.fp16:
                opt_level = "O0"
                keep_batchnorm_fp32 = False
            elif self.opt.fp16_mixed:
                opt_level = "O1"
                keep_batchnorm_fp32 = None
            else:
                opt_level = "O2"
                keep_batchnorm_fp32 = False

            if self.cuda:
                self.model, self.optim.optimizer = amp.initialize(
                    self.model,
                    self.optim.optimizer,
                    opt_level=opt_level,
                    keep_batchnorm_fp32=keep_batchnorm_fp32,
                    loss_scale="dynamic",
                    verbosity=1 if self.opt.verbose else 0)

            # wrap the model into DDP after initializing by amp
            if self.world_size > 1:
                """
                delay_allreduce is required to avoid allreduce error during backward pass
                """
                self.model = DDP(self.model,
                                 delay_allreduce=True,
                                 gradient_average=False)

                # torch DDP is more likely to work with the official amp autocast
                # self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[self.rank],
                #                                                        output_device=self.rank,
                #                                                        find_unused_parameters=True)

        print("[INFO] Process %d ready." % self.rank, flush=True)

    def is_main(self):

        return self.rank == 0

    def print(self, *content, flush=False):
        """
        A helper function to print only on the main process
        :param content:
        :param flush:
        :return:
        """
        if self.is_main():
            print(*content, flush=flush)

    def load_encoder_weight(self, checkpoint_file):

        print("Loading pretrained models from %s" % checkpoint_file)
        checkpoint = torch.load(checkpoint_file,
                                map_location=lambda storage, loc: storage)

        pretrained_model = build_model(checkpoint['opt'], checkpoint['dicts'])
        pretrained_model.load_state_dict(checkpoint['model'])

        print("Loading pretrained encoder weights ...")
        pretrained_model.encoder.language_embedding = None
        enc_language_embedding = self.model.encoder.language_embedding
        self.model.encoder.language_embedding = None
        encoder_state_dict = pretrained_model.encoder.state_dict()

        self.model.encoder.load_state_dict(encoder_state_dict)
        self.model.encoder.language_embedding = enc_language_embedding
        return

    def load_decoder_weight(self, checkpoint_file):

        self.print("Loading pretrained models from %s" % checkpoint_file)
        checkpoint = torch.load(checkpoint_file,
                                map_location=lambda storage, loc: storage)
        chkpoint_dict = checkpoint['dicts']

        pretrained_model = build_model(checkpoint['opt'], chkpoint_dict)
        pretrained_model.load_state_dict(checkpoint['model'])

        self.print("Loading pretrained decoder weights ...")
        # first we have to remove the embeddings, which probably have different sizes ...
        pretrained_word_emb = pretrained_model.decoder.word_lut
        pretrained_model.decoder.word_lut = None
        pretrained_lang_emb = pretrained_model.decoder.language_embeddings
        pretrained_model.decoder.language_embeddings = None

        # actually we assume that two decoders have the same language embeddings...
        untrained_word_emb = self.model.decoder.word_lut
        self.model.decoder.word_lut = None
        untrained_lang_emb = self.model.decoder.language_embeddings
        self.model.decoder.language_embeddings = None

        decoder_state_dict = pretrained_model.decoder.state_dict()
        self.model.decoder.load_state_dict(decoder_state_dict)

        # now we load the embeddings ....
        n_copies = 0
        for token in self.dicts['tgt'].labelToIdx:

            untrained_id = self.dicts['tgt'].labelToIdx[token]

            if token in chkpoint_dict['tgt'].labelToIdx:
                pretrained_id = chkpoint_dict['tgt'].labelToIdx[token]
                untrained_word_emb.weight.data[untrained_id].copy_(
                    pretrained_word_emb.weight.data[pretrained_id])

                self.model.generator[0].linear.bias.data[untrained_id].copy_(
                    pretrained_model.generator[0].linear.bias.
                    data[pretrained_id])
                n_copies += 1

        self.print("Copied embedding for %d words" % n_copies)
        self.model.decoder.word_lut = untrained_word_emb

        # now we load the language embeddings ...
        if pretrained_lang_emb and untrained_lang_emb and 'langs' in chkpoint_dict:
            for lang in self.dicts['langs']:

                untrained_id = self.dicts['langs'][lang]
                if lang in chkpoint_dict['langs']:
                    pretrained_id = chkpoint_dict['langs'][lang]
                    untrained_lang_emb.weight.data[untrained_id].copy_(
                        pretrained_lang_emb.weight.data[pretrained_id])

        self.model.decoder.language_embeddings = untrained_lang_emb

    def warm_up(self):
        """
        Warmup the memory allocator, by attempting to fit the largest batch
        :return:
        """

        if self.opt.memory_profiling:
            from pytorch_memlab import MemReporter
            reporter = MemReporter()
        batch = self.train_data[0].get_largest_batch() if isinstance(self.train_data, list) \
            else self.train_data.get_largest_batch()
        opt = self.opt

        if self.cuda:
            batch.cuda(fp16=self.opt.fp16 and not self.opt.fp16_mixed)

        self.model.train()
        self.loss_function.train()
        self.model.zero_grad()
        oom = False

        if self.opt.memory_profiling:
            self.print("Input size: ")
            self.print(batch.size, batch.src_size, batch.tgt_size)

        if opt.streaming:
            streaming_state = self.model.init_stream()
        else:
            streaming_state = None

        try:
            targets = batch.get('target_output')
            tgt_mask = None
            outputs = self.model(batch,
                                 streaming=opt.streaming,
                                 target_mask=tgt_mask,
                                 zero_encoder=opt.zero_encoder,
                                 mirror=opt.mirror_loss,
                                 streaming_state=streaming_state,
                                 nce=opt.nce)

            outputs['tgt_mask'] = tgt_mask

            loss_dict = self.loss_function(outputs, targets, model=self.model)
            loss_data = loss_dict['data']
            loss = loss_dict['loss']  # a little trick to avoid gradient overflow with fp16
            full_loss = loss

            if opt.mirror_loss:
                rev_loss = loss_dict['rev_loss']
                mirror_loss = loss_dict['mirror_loss']
                full_loss = full_loss + rev_loss + mirror_loss

            # reconstruction loss
            if opt.reconstruct:
                rec_loss = loss_dict['rec_loss']
                full_loss = full_loss + rec_loss

            if opt.lfv_multilingual:
                lid_logits = outputs['lid_logits']
                lid_labels = batch.get('target_lang')
                lid_loss_function = self.loss_function.get_loss_function(
                    'lid_loss')
                lid_loss = lid_loss_function(lid_logits, lid_labels)
                full_loss = full_loss + lid_loss

            optimizer = self.optim.optimizer

            if self.opt.memory_profiling:
                reporter.report(verbose=True)

                # for obj in gc.get_objects():
                #     try:
                #         if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
                #             # print(varname(obj))
                #             # we can rule out parameter cost later
                #             # if 'parameter' not in type(obj):
                #             # if len(obj.shape) == 3:
                #             # if not isinstance(obj, torch.nn.parameter.Parameter):
                #             #     tensor = obj
                #             #     numel = tensor.
                #             print(type(obj), obj.type(), obj.size())
                #     except:
                #         pass

                # print("Memory profiling complete.")
                # print(torch.cuda.memory_summary())
                # exit()

            if self.cuda:
                with amp.scale_loss(full_loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.div_(batch.tgt_size).backward()

            if self.opt.memory_profiling:
                print('========= after backward =========')
                reporter.report(verbose=True)

            self.model.zero_grad()
            self.optim.zero_grad()
            # self.optim.step()
            # self.optim.reset()

        except RuntimeError as e:
            if 'out of memory' in str(e):
                oom = True
            else:
                raise e

        if oom:
            print(
                "* Warning: out of memory during warm-up; the largest batch is too big for the GPU.",
                flush=True)
        else:
            print("* Warming up successfully.", flush=True)

        if self.opt.memory_profiling:
            if hasattr(torch.cuda, 'memory_summary'):
                print(torch.cuda.memory_summary())
            exit()


    def save(self, epoch, valid_ppl, itr=None):

        opt = self.opt
        model = self.model
        dicts = self.dicts

        model_state_dict = self.model.state_dict()
        optim_state_dict = self.optim.state_dict()

        if itr:
            itr_state_dict = itr.state_dict()
        else:
            itr_state_dict = None

        #  drop a checkpoint
        checkpoint = {
            'model': model_state_dict,
            'dicts': dicts,
            'opt': opt,
            'epoch': epoch,
            'itr': itr_state_dict,
            'optim': optim_state_dict,
            'amp': amp.state_dict()
        }

        file_name = '%s_ppl_%.6f_e%.2f.pt' % (opt.save_model, valid_ppl, epoch)
        print('Writing to %s' % file_name)
        torch.save(checkpoint, file_name)

        # check the save directory here
        checkpoint_dir = os.path.dirname(opt.save_model)
        existed_save_files = checkpoint_paths(checkpoint_dir)
        for save_file in existed_save_files[opt.keep_save_files:]:
            print(" * Deleting old save file %s ...." % save_file)
            os.remove(save_file)

    def eval(self, data):
        opt = self.opt
        rank = self.device
        world_size = self.world_size

        # the data iterator creates an epoch iterator
        data_iterator = generate_data_iterator(data,
                                               rank,
                                               world_size,
                                               seed=self.opt.seed,
                                               num_workers=opt.num_workers,
                                               epoch=1,
                                               buffer_size=opt.buffer_size)
        epoch_iterator = data_iterator.next_epoch_itr(False, pin_memory=False)

        data_size = len(epoch_iterator)
        i = 0

        self.model.eval()
        self.loss_function.eval()
        # self.model.module.reset_states()

        total_loss = zero_tensor()
        total_words = zero_tensor()

        if opt.streaming:
            streaming_state = self.model.init_stream()
        else:
            streaming_state = None

        with torch.no_grad():
            while not data_iterator.end_of_epoch():
                samples = next(epoch_iterator)

                if samples:
                    batch = prepare_sample(samples,
                                           device=self.device,
                                           fp16=self.opt.fp16
                                           and not self.opt.fp16_mixed)

                    targets = batch.get('target_output')
                    tgt_mask = targets.ne(onmt.constants.PAD)
                    outputs = self.model(batch,
                                         streaming=opt.streaming,
                                         target_mask=tgt_mask,
                                         mirror=opt.mirror_loss,
                                         streaming_state=streaming_state,
                                         nce=opt.nce)

                    outputs['tgt_mask'] = tgt_mask
                    loss_dict = self.loss_function(outputs,
                                                   targets,
                                                   model=self.model,
                                                   eval=True)
                    loss_data = loss_dict['data']

                    total_loss.add_(loss_data)
                    total_words.add_(batch.tgt_size)
                    i = i + 1

        # allreduce the total loss and total words from other processes
        dist.all_reduce(total_loss, op=dist.ReduceOp.SUM, group=self.group)
        dist.all_reduce(total_words, op=dist.ReduceOp.SUM, group=self.group)

        self.model.train()
        self.loss_function.train()

        return total_loss / total_words

    def train_epoch(self, epoch, resume=False, itr_progress=None):

        global rec_ppl
        opt = self.opt
        train_data = self.train_data
        streaming = opt.streaming

        # Clear the gradients of the model
        self.model.zero_grad()
        # self.model.module.reset_states()

        dataset = train_data
        data_iterator = generate_data_iterator(dataset,
                                               self.rank,
                                               self.world_size,
                                               seed=self.opt.seed,
                                               num_workers=opt.num_workers,
                                               epoch=epoch,
                                               buffer_size=opt.buffer_size)

        # TODO: fix resume which is currently buggy
        if resume:
            data_iterator.load_state_dict(itr_progress)

        epoch_iterator = data_iterator.next_epoch_itr(
            not streaming, pin_memory=opt.pin_memory)

        total_tokens, total_loss, total_words = zero_tensor(), zero_tensor(), zero_tensor()
        total_non_pads = zero_tensor()
        report_loss, report_tgt_words = zero_tensor(), zero_tensor()
        report_src_words = zero_tensor()
        report_rec_loss, report_rev_loss, report_mirror_loss = zero_tensor(), zero_tensor(), zero_tensor()
        start = time.time()
        n_samples = len(data_iterator)

        counter = 0
        num_accumulated_words = zero_tensor()
        num_accumulated_sents = zero_tensor()
        grad_scaler = 1

        nan = False
        nan_counter = zero_tensor()

        if opt.streaming:
            streaming_state = self.model.init_stream()
        else:
            streaming_state = None

        i = data_iterator.iterations_in_epoch if not isinstance(
            train_data, list) else epoch_iterator.n_yielded
        i = i * self.world_size

        while not data_iterator.end_of_epoch():

            curriculum = (epoch < opt.curriculum)

            # this batch generator is not very clean atm
            # TODO: move everything to the multiGPU trainer
            samples = next(epoch_iterator)

            batch = prepare_sample(samples,
                                   device=self.device,
                                   fp16=self.opt.fp16
                                   and not self.opt.fp16_mixed)

            if opt.streaming:
                if train_data.is_new_stream():
                    streaming_state = self.model.init_stream()
            else:
                streaming_state = None

            # TODO: dealing with oom during distributed training
            oom = False

            try:
                # outputs is a dictionary containing keys/values necessary for loss function
                # can be flexibly controlled within models for easier extensibility
                counter = counter + 1
                # reduction_disabled = False if counter >= opt.update_frequency or i == (n_samples-1) else True
                # self.model.require_backward_grad_sync = not reduction_disabled

                targets = batch.get('target_output')
                tgt_mask = targets.ne(onmt.constants.PAD)
                outputs = self.model(batch,
                                     streaming=opt.streaming,
                                     target_mask=tgt_mask,
                                     zero_encoder=opt.zero_encoder,
                                     mirror=opt.mirror_loss,
                                     streaming_state=streaming_state,
                                     nce=opt.nce)

                batch_size = batch.size
                outputs['tgt_mask'] = tgt_mask

                loss_dict = self.loss_function(outputs,
                                               targets,
                                               model=self.model)
                loss_data = loss_dict['data']
                loss = loss_dict['loss']  # a little trick to avoid gradient overflow with fp16
                full_loss = loss

                if opt.mirror_loss:
                    rev_loss = loss_dict['rev_loss']
                    rev_loss_data = loss_dict['rev_loss_data']
                    mirror_loss = loss_dict['mirror_loss']
                    full_loss = full_loss + rev_loss + mirror_loss
                    mirror_loss_data = loss_dict['mirror_loss'].item()
                else:
                    rev_loss_data = None
                    mirror_loss_data = 0

                # reconstruction loss
                if opt.reconstruct:
                    rec_loss = loss_dict['rec_loss']
                    full_loss = full_loss + rec_loss
                    rec_loss_data = loss_dict['rec_loss_data']
                else:
                    rec_loss_data = None

                if opt.lfv_multilingual:
                    lid_logits = outputs['lid_logits']
                    lid_labels = batch.get('target_lang')
                    lid_loss_function = self.loss_function.get_loss_function(
                        'lid_loss')
                    lid_loss = lid_loss_function(lid_logits, lid_labels)
                    full_loss = full_loss + lid_loss

                optimizer = self.optim.optimizer

                # When the batch size is large, each gradient step is very easy to explode on fp16
                # Normalizing the loss to grad scaler ensures this will not happen
                full_loss.div_(grad_scaler)

                # reduction_disabled = False
                with amp.scale_loss(full_loss, optimizer) as scaled_loss:
                    scaled_loss.backward()

                del outputs

            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('[WARNING]: ran out of memory on GPU %d' % self.rank,
                          flush=True)
                    oom = True
                    torch.cuda.empty_cache()
                    loss = 0
                    if opt.streaming:  # reset stream in this case ...
                        streaming_state = self.model.init_stream()
                    raise e
                else:
                    raise e

            batch_size = batch.size

            src_size = batch.src_size
            tgt_size = batch.tgt_size
            num_accumulated_words.add_(tgt_size)
            num_accumulated_sents.add_(batch_size)

            #   We only update the parameters after getting gradients from n mini-batches
            update_flag = False
            if counter >= opt.update_frequency:
                update_flag = True
            elif i == n_samples - 1:  # update for the last minibatch
                update_flag = True

            if update_flag:
                # accumulated-gradient case: sync the accumulated word count across
                # workers before normalizing and applying the update
                dist.all_reduce(num_accumulated_words,
                                op=dist.ReduceOp.SUM,
                                group=self.group)

                # if (counter == 1 and self.opt.update_frequency != 1) or counter > 1:
                grad_denom = 1 / grad_scaler
                if self.opt.normalize_gradient:
                    grad_denom = num_accumulated_words.item() * grad_denom
                else:
                    grad_denom = 1
                # When we accumulate the gradients, each gradient is already normalized by a constant grad_scaler
                normalize_gradients(amp.master_params(optimizer), grad_denom)
                # Update the parameters.

                self.optim.step()
                self.optim.zero_grad()
                self.model.zero_grad()
                counter = 0
                num_accumulated_words.zero_()
                num_accumulated_sents.zero_()

                num_updates = self.optim._step
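                # -1 % opt.save_every == opt.save_every - 1, so this fires once
                # every opt.save_every updates (on the last update of each window)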
                if opt.save_every > 0 and num_updates % opt.save_every == -1 % opt.save_every:
                    valid_loss = self.eval(self.valid_data)
                    valid_ppl = math.exp(min(valid_loss, 100))

                    if self.is_main():
                        print('Validation perplexity: %g' % valid_ppl)
                        ep = float(epoch) - 1. + ((float(i) + 1.) / n_samples)
                        self.save(ep, valid_ppl, itr=data_iterator)

            num_words = tgt_size
            report_loss.add_(loss_data)
            report_tgt_words.add_(num_words)
            report_src_words.add_(src_size)
            total_loss.add_(loss_data)
            total_words.add_(num_words)
            # total_tokens += batch.get('target_output').nelement()
            # total_non_pads += batch.get('target_output').ne(onmt.constants.PAD).sum().item()
            # batch_efficiency = total_non_pads / total_tokens

            if opt.reconstruct:
                report_rec_loss.add_(rec_loss_data)

            if opt.mirror_loss:
                report_rev_loss.add_(rev_loss_data)
                report_mirror_loss.add_(mirror_loss_data)

            # control the index a little bit to ensure the log is always printed
            if i == 0 or ((i + 1) % opt.log_interval < self.world_size):

                dist.all_reduce(report_loss,
                                op=dist.ReduceOp.SUM,
                                group=self.group)
                dist.all_reduce(report_tgt_words,
                                op=dist.ReduceOp.SUM,
                                group=self.group)
                dist.all_reduce(report_src_words,
                                op=dist.ReduceOp.SUM,
                                group=self.group)

                if self.is_main():
                    log_string = ("Epoch %2d, %5d/%5d; ; ppl: %6.2f ; " %
                                  (epoch, i + 1, len(data_iterator),
                                   math.exp(report_loss.item() /
                                            report_tgt_words.item())))

                    if opt.reconstruct:
                        dist.all_reduce(report_rec_loss,
                                        op=dist.ReduceOp.SUM,
                                        group=self.group)
                        rec_ppl = math.exp(report_rec_loss.item() /
                                           report_src_words.item())
                        log_string += (" rec_ppl: %6.2f ; " % rec_ppl)

                    if opt.mirror_loss:
                        dist.all_reduce(report_rev_loss,
                                        op=dist.ReduceOp.SUM,
                                        group=self.group)
                        rev_ppl = math.exp(report_rev_loss.item() /
                                           report_tgt_words.item())
                        log_string += (" rev_ppl: %6.2f ; " % rev_ppl)
                        log_string += (" mir_loss: %6.2f ; " %
                                       (report_mirror_loss / report_tgt_words))

                    log_string += (
                        "lr: %.7f ; updates: %7d; " %
                        (self.optim.getLearningRate(), self.optim._step))

                    log_string += (
                        "%5.0f src tok/s; %5.0f tgt tok/s; " %
                        (report_src_words.item() /
                         (time.time() - start), report_tgt_words.item() /
                         (time.time() - start)))

                    log_string += ("%s elapsed" % str(
                        datetime.timedelta(seconds=int(time.time() -
                                                       self.start_time))))

                    self.print(log_string, flush=True)

                report_loss.zero_()
                report_tgt_words.zero_()
                report_src_words.zero_()
                report_rec_loss.zero_()
                report_rev_loss.zero_()
                report_mirror_loss.zero_()
                start = time.time()

            # increase i by world size
            i = i + self.world_size

        return total_loss / total_words

    # def run(self, save_file=None):
    def run(self, checkpoint=None):

        opt = self.opt

        if checkpoint is not None:
            raise NotImplementedError

            # TODO: have loading checkpoints for each process

            self.model.load_state_dict(checkpoint['model'])
            prec_opt = checkpoint['opt'] if 'opt' in checkpoint else None

            opt.reset_optim = True

            if not opt.reset_optim:

                if self.is_main():
                    print("* Loading optimizer states ... ")
                self.optim.load_state_dict(checkpoint['optim'])
                if prec_opt is not None and hasattr(prec_opt, "fp16_mixed"):
                    # Only load amp information if the mode is the same
                    # Maybe its better to change between optimization mode?
                    if opt.fp16_mixed == prec_opt.fp16_mixed and opt.fp16 == prec_opt.fp16:
                        if 'amp' in checkpoint:
                            try:
                                amp.load_state_dict(checkpoint['amp'])
                            except Exception:
                                # loading the amp state can fail
                                pass

                # Only load the progress when we use the same optimizer
                if 'itr' in checkpoint:
                    itr_progress = checkpoint['itr']
                else:
                    itr_progress = None

                resume = True
                start_epoch = checkpoint['epoch'] if 'epoch' in checkpoint else 1
                if start_epoch is None:
                    start_epoch = 1
            else:
                itr_progress = None
                resume = False
                start_epoch = 1

            del checkpoint['model']
            optim_state_dict = checkpoint['optim']
            # del checkpoint['optim']
            del checkpoint
        else:
            itr_progress = None

            resume = False
            start_epoch = 1

        if opt.load_encoder_from:
            self.load_encoder_weight(opt.load_encoder_from)
        if opt.load_decoder_from:
            self.load_decoder_weight(opt.load_decoder_from)

        # if we are on a GPU: warm up the memory allocator
        if self.cuda:
            self.warm_up()

        valid_loss = self.eval(self.valid_data)
        valid_ppl = math.exp(min(valid_loss, 100))

        if self.is_main():
            print('[INFO] Validation perplexity: %g' % valid_ppl, flush=True)

        self.start_time = time.time()

        for epoch in range(start_epoch, start_epoch + opt.epochs):
            self.print('')

            #  (1) train for one epoch on the training set
            train_loss = self.train_epoch(epoch,
                                          resume=resume,
                                          itr_progress=itr_progress)
            train_ppl = math.exp(min(train_loss, 100))
            self.print('[INFO] Train perplexity: %g' % train_ppl)

            #  (2) evaluate on the validation set
            valid_loss = self.eval(self.valid_data)
            valid_ppl = math.exp(min(valid_loss, 100))

            if self.is_main():
                print('[INFO] Validation perplexity: %g' % valid_ppl)
                self.save(epoch, valid_ppl)

            itr_progress = None
            resume = False
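zero_tensor, used throughout this trainer as an accumulator for dist.all_reduce, is not defined in the listing; a plausible sketch, assuming it returns a scalar tensor placed on the current GPU so collectives can operate on it:

import torch

def zero_tensor():
    # scalar accumulator; kept on the GPU so dist.all_reduce can reduce it in-place
    if torch.cuda.is_available():
        return torch.zeros(1).cuda()
    return torch.zeros(1)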
Example #21
def train(args, model):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        os.makedirs(args.output_dir, exist_ok=True)
        writer = SummaryWriter(log_dir=os.path.join("logs", args.name))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    # Prepare dataset
    train_loader, test_loader = get_loader(args)

    # Prepare optimizer and scheduler
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.learning_rate,
                                momentum=0.9,
                                weight_decay=args.weight_decay)
    t_total = args.num_steps
    if args.decay_type == "cosine":
        scheduler = WarmupCosineSchedule(optimizer,
                                         warmup_steps=args.warmup_steps,
                                         t_total=t_total)
    else:
        scheduler = WarmupLinearSchedule(optimizer,
                                         warmup_steps=args.warmup_steps,
                                         t_total=t_total)

    if args.fp16:
        model, optimizer = amp.initialize(models=model,
                                          optimizers=optimizer,
                                          opt_level=args.fp16_opt_level)
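        # start the dynamic loss scale high; amp backs it off automatically on overflow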
        amp._amp_state.loss_scalers[0]._loss_scale = 2**20

    # Distributed training
    if args.local_rank != -1:
        model = DDP(model,
                    message_size=250000000,
                    gradient_predivide_factor=get_world_size())

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Total optimization steps = %d", args.num_steps)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)

    model.zero_grad()
    set_seed(args)  # Added here for reproducibility (even between Python 2 and 3)
    losses = AverageMeter()
    global_step, best_acc = 0, 0
    while True:
        model.train()
        epoch_iterator = tqdm(train_loader,
                              desc="Training (X / X Steps) (loss=X.X)",
                              bar_format="{l_bar}{r_bar}",
                              dynamic_ncols=True,
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            batch = tuple(t.to(args.device) for t in batch)
            x, y = batch
            loss = model(x, y)

            if args.gradient_accumulation_steps > 1:
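                # Scale the loss down so that gradients accumulated over
                # several micro-batches match one full-batch update.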
                loss = loss / args.gradient_accumulation_steps
            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            if (step + 1) % args.gradient_accumulation_steps == 0:
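                # Undo the accumulation scaling so the logged loss stays
                # comparable across accumulation settings.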
                losses.update(loss.item() * args.gradient_accumulation_steps)
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # step the optimizer before the scheduler (required since PyTorch 1.1)
                optimizer.zero_grad()
                global_step += 1

                epoch_iterator.set_description(
                    "Training (%d / %d Steps) (loss=%2.5f)" %
                    (global_step, t_total, losses.val))
                if args.local_rank in [-1, 0]:
                    writer.add_scalar("train/loss",
                                      scalar_value=losses.val,
                                      global_step=global_step)
                    writer.add_scalar("train/lr",
                                      scalar_value=scheduler.get_lr()[0],
                                      global_step=global_step)
                if (global_step % args.eval_every == 0
                        and args.local_rank in [-1, 0]):
                    accuracy = valid(args, model, writer, test_loader,
                                     global_step)
                    if best_acc < accuracy:
                        save_model(args, model)
                        best_acc = accuracy
                    model.train()

                if global_step % t_total == 0:
                    break
        losses.reset()
        if global_step % t_total == 0:
            break

    if args.local_rank in [-1, 0]:
        writer.close()
    logger.info("Best Accuracy: \t%f" % best_acc)
    logger.info("End Training!")
Example #22
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--data_dir",
        default='../absa_data/twitter',
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
        "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument("--task_name",
                        default='twitter',
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )

    ## Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=64,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument(
        "--max_entity_length",
        default=16,
        type=int,
        help=
        "The maximum entity input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=8.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--fine_tune_cnn',
                        action='store_true',
                        help='fine tune pre-trained CNN if True')
    parser.add_argument('--resnet_root',
                        default='./resnet',
                        help='path to the pre-trained cnn models')
    parser.add_argument('--crop_size',
                        type=int,
                        default=224,
                        help='crop size of image')
    parser.add_argument(
        '--path_image',
        default='../pytorch-pretrained-BERT/twitter_subimages/',
        help='path to images')
    parser.add_argument(
        '--mm_model', default='TomBert', help='model name'
    )  # TomBert, TomBertNoPooling, MBert, MBertNoPooling, ResBert
    parser.add_argument('--pooling', default='first',
                        help='pooling method')  # first, cls, concat
    parser.add_argument('--bertlayer',
                        action='store_true',
                        help='whether to add another bert layer')
    parser.add_argument('--tfn',
                        action='store_true',
                        help='whether to use TFN')

    args = parser.parse_args()

    print("**************current model: " + args.mm_model +
          "******************")
    if args.mm_model == "ResBert" and args.bertlayer:
        print("add another bert layer")
        if args.mm_model == "ResBert" and args.tfn:
            print("add another tfn layer")
    elif args.mm_model == "TomBert" or args.mm_model == "MBert":
        print("pooling method: " + args.pooling)
    print("*" * 50)

    if args.task_name == "twitter":  # this refers to twitter-2017 dataset
        args.path_image = "../IJCAI2019_data/twitter2017_images/"
    elif args.task_name == "twitter2015":  # this refers to twitter-2015 dataset
        args.path_image = "../IJCAI2019_data/twitter2015_images/"
    else:
        print("The task name is not right!")

    processors = {
        "twitter2015": AbmsaProcessor,  # our twitter-2015 dataset
        "twitter": AbmsaProcessor  # our twitter-2017 dataset
    }

    num_labels_task = {
        "twitter2015": 3,  # our twitter-2015 dataset
        "twitter": 3  # our twitter-2017 dataset
    }

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if os.path.exists(args.output_dir) and os.listdir(
            args.output_dir) and args.do_train:
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
    os.makedirs(args.output_dir, exist_ok=True)

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    num_labels = num_labels_task[task_name]
    label_list = processor.get_labels()

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    train_examples = None
    num_train_steps = None
    if args.do_train:
        train_examples = processor.get_train_examples(args.data_dir)
        num_train_steps = int(
            len(train_examples) / args.train_batch_size /
            args.gradient_accumulation_steps * args.num_train_epochs)

    # Prepare model
    if args.mm_model == 'ResBert':
        model = ResBertForMMSequenceClassification.from_pretrained(
            args.bert_model,
            cache_dir=PYTORCH_PRETRAINED_BERT_CACHE /
            'distributed_{}'.format(args.local_rank),
            num_labels=num_labels,
            bert_layer=args.bertlayer,
            tfn=args.tfn)
    elif args.mm_model == 'MBert':
        model = MBertForMMSequenceClassification.from_pretrained(
            args.bert_model,
            cache_dir=PYTORCH_PRETRAINED_BERT_CACHE /
            'distributed_{}'.format(args.local_rank),
            num_labels=num_labels,
            pooling=args.pooling)
    elif args.mm_model == 'MBertNoPooling':
        model = MBertNoPoolingForMMSequenceClassification.from_pretrained(
            args.bert_model,
            cache_dir=PYTORCH_PRETRAINED_BERT_CACHE /
            'distributed_{}'.format(args.local_rank),
            num_labels=num_labels)
    elif args.mm_model == 'TomBertNoPooling':
        model = TomBertNoPoolingForMMSequenceClassification.from_pretrained(
            args.bert_model,
            cache_dir=PYTORCH_PRETRAINED_BERT_CACHE /
            'distributed_{}'.format(args.local_rank),
            num_labels=num_labels)
    else:  # TomBert by default
        model = TomBertForMMSequenceClassification.from_pretrained(
            args.bert_model,
            cache_dir=PYTORCH_PRETRAINED_BERT_CACHE /
            'distributed_{}'.format(args.local_rank),
            num_labels=num_labels,
            pooling=args.pooling)
    net = getattr(resnet, 'resnet152')()
    net.load_state_dict(
        torch.load(os.path.join(args.resnet_root, 'resnet152.pth')))
    encoder = myResnet(net, args.fine_tune_cnn, device)
    if args.fp16:
        model.half()
        encoder.half()
    model.to(device)
    encoder.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        model = DDP(model)
        encoder = DDP(encoder)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)
        encoder = torch.nn.DataParallel(encoder)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    t_total = num_train_steps
    if args.local_rank != -1:
        t_total = t_total // torch.distributed.get_world_size()
    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=args.loss_scale)

    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=t_total)

    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
    output_encoder_file = os.path.join(args.output_dir, "pytorch_encoder.bin")
    if args.do_train:
        train_features = convert_mm_examples_to_features(
            train_examples, label_list, args.max_seq_length,
            args.max_entity_length, tokenizer, args.crop_size, args.path_image)

        all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                      dtype=torch.long)
        all_added_input_mask = torch.tensor(
            [f.added_input_mask for f in train_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                       dtype=torch.long)
        all_s2_input_ids = torch.tensor(
            [f.s2_input_ids for f in train_features], dtype=torch.long)
        all_s2_input_mask = torch.tensor(
            [f.s2_input_mask for f in train_features], dtype=torch.long)
        all_s2_segment_ids = torch.tensor(
            [f.s2_segment_ids for f in train_features], dtype=torch.long)
        all_img_feats = torch.stack([f.img_feat for f in train_features])
        all_label_ids = torch.tensor([f.label_id for f in train_features],
                                     dtype=torch.long)
        train_data = TensorDataset(all_input_ids, all_input_mask, all_added_input_mask, all_segment_ids,\
                                   all_s2_input_ids, all_s2_input_mask, all_s2_segment_ids,
                                   all_img_feats, all_label_ids)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)
        eval_examples = processor.get_dev_examples(args.data_dir)
        eval_features = convert_mm_examples_to_features(
            eval_examples, label_list, args.max_seq_length,
            args.max_entity_length, tokenizer, args.crop_size, args.path_image)

        all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                      dtype=torch.long)
        all_added_input_mask = torch.tensor(
            [f.added_input_mask for f in eval_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                       dtype=torch.long)
        all_s2_input_ids = torch.tensor(
            [f.s2_input_ids for f in eval_features], dtype=torch.long)
        all_s2_input_mask = torch.tensor(
            [f.s2_input_mask for f in eval_features], dtype=torch.long)
        all_s2_segment_ids = torch.tensor(
            [f.s2_segment_ids for f in eval_features], dtype=torch.long)
        all_img_feats = torch.stack([f.img_feat for f in eval_features])
        all_label_ids = torch.tensor([f.label_id for f in eval_features],
                                     dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask, all_added_input_mask, all_segment_ids, \
                                  all_s2_input_ids, all_s2_input_mask, all_s2_segment_ids,\
                                  all_img_feats, all_label_ids)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

        max_acc = 0.0
        logger.info("*************** Running training ***************")
        for train_idx in trange(int(args.num_train_epochs), desc="Epoch"):

            logger.info("********** Epoch: " + str(train_idx) + " **********")
            logger.info("  Num examples = %d", len(train_examples))
            logger.info("  Batch size = %d", args.train_batch_size)
            logger.info("  Num steps = %d", num_train_steps)
            model.train()
            encoder.train()
            encoder.zero_grad()
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, added_input_mask, segment_ids, s2_input_ids, s2_input_mask, s2_segment_ids, \
                img_feats, label_ids = batch
                with torch.no_grad():
                    imgs_f, img_mean, img_att = encoder(img_feats)
                if train_idx == 0 and step == 0:
                    loss = model(input_ids, s2_input_ids, img_att, segment_ids, s2_segment_ids, input_mask, s2_input_mask, \
                             added_input_mask, label_ids, True)
                else:
                    loss = model(input_ids, s2_input_ids, img_att, segment_ids, s2_segment_ids, input_mask, s2_input_mask, \
                             added_input_mask, label_ids)
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    # modify learning rate with special warm up BERT uses
                    lr_this_step = args.learning_rate * warmup_linear(
                        global_step / t_total, args.warmup_proportion)
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

            logger.info("***** Running evaluation on Dev Set*****")
            logger.info("  Num examples = %d", len(eval_examples))
            logger.info("  Batch size = %d", args.eval_batch_size)
            model.eval()
            encoder.eval()
            eval_loss, eval_accuracy = 0, 0
            nb_eval_steps, nb_eval_examples = 0, 0

            true_label_list = []
            pred_label_list = []

            for input_ids, input_mask, added_input_mask, segment_ids,  s2_input_ids, s2_input_mask, s2_segment_ids, \
                img_feats, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
                input_ids = input_ids.to(device)
                input_mask = input_mask.to(device)
                added_input_mask = added_input_mask.to(device)
                segment_ids = segment_ids.to(device)
                s2_input_ids = s2_input_ids.to(device)
                s2_input_mask = s2_input_mask.to(device)
                s2_segment_ids = s2_segment_ids.to(device)
                img_feats = img_feats.to(device)
                label_ids = label_ids.to(device)

                with torch.no_grad():
                    imgs_f, img_mean, img_att = encoder(img_feats)
                    tmp_eval_loss = model(input_ids, s2_input_ids, img_att,
                                          segment_ids, s2_segment_ids,
                                          input_mask, s2_input_mask,
                                          added_input_mask, label_ids)
                    logits = model(input_ids, s2_input_ids, img_att,
                                   segment_ids, s2_segment_ids, input_mask,
                                   s2_input_mask, added_input_mask)

                logits = logits.detach().cpu().numpy()
                label_ids = label_ids.to('cpu').numpy()
                true_label_list.append(label_ids)
                pred_label_list.append(logits)
                tmp_eval_accuracy = accuracy(logits, label_ids)

                eval_loss += tmp_eval_loss.mean().item()
                eval_accuracy += tmp_eval_accuracy

                nb_eval_examples += input_ids.size(0)
                nb_eval_steps += 1

            eval_loss = eval_loss / nb_eval_steps
            eval_accuracy = eval_accuracy / nb_eval_examples
            loss = tr_loss / nb_tr_steps if args.do_train else None
            true_label = np.concatenate(true_label_list)
            pred_outputs = np.concatenate(pred_label_list)
            precision, recall, F_score = macro_f1(true_label, pred_outputs)
            result = {
                'eval_loss': eval_loss,
                'eval_accuracy': eval_accuracy,
                'f_score': F_score,
                'global_step': global_step,
                'loss': loss
            }

            logger.info("***** Dev Eval results *****")
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))

            if eval_accuracy >= max_acc:
                # Save a trained model
                model_to_save = model.module if hasattr(
                    model, 'module') else model  # Only save the model itself
                encoder_to_save = encoder.module if hasattr(
                    encoder, 'module') else encoder  # Only save the model itself
                if args.do_train:
                    torch.save(model_to_save.state_dict(), output_model_file)
                    torch.save(encoder_to_save.state_dict(),
                               output_encoder_file)
                max_acc = eval_accuracy

    # Load a trained model that you have fine-tuned
    model_state_dict = torch.load(output_model_file)
    if args.mm_model == 'ResBert':
        model = ResBertForMMSequenceClassification.from_pretrained(
            args.bert_model,
            state_dict=model_state_dict,
            num_labels=num_labels,
            bert_layer=args.bertlayer,
            tfn=args.tfn)
    elif args.mm_model == 'MBert':
        model = MBertForMMSequenceClassification.from_pretrained(
            args.bert_model,
            state_dict=model_state_dict,
            num_labels=num_labels,
            pooling=args.pooling)
    elif args.mm_model == 'MBertNoPooling':
        model = MBertNoPoolingForMMSequenceClassification.from_pretrained(
            args.bert_model,
            state_dict=model_state_dict,
            num_labels=num_labels)
    elif args.mm_model == 'TomBertNoPooling':
        model = TomBertNoPoolingForMMSequenceClassification.from_pretrained(
            args.bert_model,
            state_dict=model_state_dict,
            num_labels=num_labels)
    else:  # TomBert by default
        model = TomBertForMMSequenceClassification.from_pretrained(
            args.bert_model,
            state_dict=model_state_dict,
            num_labels=num_labels,
            pooling=args.pooling)
    model.to(device)
    encoder_state_dict = torch.load(output_encoder_file)
    encoder.load_state_dict(encoder_state_dict)
    encoder.to(device)

    if args.do_eval and (args.local_rank == -1
                         or torch.distributed.get_rank() == 0):
        eval_examples = processor.get_test_examples(args.data_dir)
        eval_features = convert_mm_examples_to_features(
            eval_examples, label_list, args.max_seq_length,
            args.max_entity_length, tokenizer, args.crop_size, args.path_image)
        logger.info("***** Running evaluation on Test Set*****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)
        all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                      dtype=torch.long)
        all_added_input_mask = torch.tensor(
            [f.added_input_mask for f in eval_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                       dtype=torch.long)
        all_s2_input_ids = torch.tensor(
            [f.s2_input_ids for f in eval_features], dtype=torch.long)
        all_s2_input_mask = torch.tensor(
            [f.s2_input_mask for f in eval_features], dtype=torch.long)
        all_s2_segment_ids = torch.tensor(
            [f.s2_segment_ids for f in eval_features], dtype=torch.long)
        all_img_feats = torch.stack([f.img_feat for f in eval_features])
        all_label_ids = torch.tensor([f.label_id for f in eval_features],
                                     dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask, all_added_input_mask, all_segment_ids, \
                                  all_s2_input_ids, all_s2_input_mask, all_s2_segment_ids,
                                  all_img_feats, all_label_ids)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

        model.eval()
        encoder.eval()
        eval_loss, eval_accuracy = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0

        true_label_list = []
        pred_label_list = []

        for input_ids, input_mask, added_input_mask, segment_ids,  s2_input_ids, s2_input_mask, s2_segment_ids, \
            img_feats, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            added_input_mask = added_input_mask.to(device)
            segment_ids = segment_ids.to(device)
            s2_input_ids = s2_input_ids.to(device)
            s2_input_mask = s2_input_mask.to(device)
            s2_segment_ids = s2_segment_ids.to(device)
            img_feats = img_feats.to(device)
            label_ids = label_ids.to(device)

            with torch.no_grad():
                imgs_f, img_mean, img_att = encoder(img_feats)
                tmp_eval_loss = model(input_ids, s2_input_ids, img_att,
                                      segment_ids, s2_segment_ids, input_mask,
                                      s2_input_mask, added_input_mask,
                                      label_ids)
                logits = model(input_ids, s2_input_ids, img_att, segment_ids,
                               s2_segment_ids, input_mask, s2_input_mask,
                               added_input_mask)

            logits = logits.detach().cpu().numpy()
            label_ids = label_ids.to('cpu').numpy()
            true_label_list.append(label_ids)
            pred_label_list.append(logits)
            tmp_eval_accuracy = accuracy(logits, label_ids)

            eval_loss += tmp_eval_loss.mean().item()
            eval_accuracy += tmp_eval_accuracy

            nb_eval_examples += input_ids.size(0)
            nb_eval_steps += 1

        eval_loss = eval_loss / nb_eval_steps
        eval_accuracy = eval_accuracy / nb_eval_examples
        loss = tr_loss / nb_tr_steps if args.do_train else None
        true_label = np.concatenate(true_label_list)
        pred_outputs = np.concatenate(pred_label_list)
        precision, recall, F_score = macro_f1(true_label, pred_outputs)
        result = {
            'eval_loss': eval_loss,
            'eval_accuracy': eval_accuracy,
            'precision': precision,
            'recall': recall,
            'f_score': F_score,
            'global_step': global_step,
            'loss': loss
        }

        pred_label = np.argmax(pred_outputs, axis=-1)
        fout_p = open(os.path.join(args.output_dir, "pred.txt"), 'w')
        fout_t = open(os.path.join(args.output_dir, "true.txt"), 'w')

        for i in range(len(pred_label)):
            attstr = str(pred_label[i])
            fout_p.write(attstr + '\n')
        for i in range(len(true_label)):
            attstr = str(true_label[i])
            fout_t.write(attstr + '\n')

        fout_p.close()
        fout_t.close()

        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Test Eval results *****")
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
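
The accuracy() helper used above is not shown in this example. Given that its return values are summed and later divided by the number of examples, a plausible implementation (an assumption, not the source's code) is an argmax match count:

import numpy as np

def accuracy(logits: np.ndarray, labels: np.ndarray) -> int:
    # Count correct argmax predictions in one batch; the caller sums these
    # counts and divides by nb_eval_examples to obtain a fraction.
    preds = np.argmax(logits, axis=-1)
    return int((preds == labels).sum())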
Example #23
def main():
    logger.info("Running %s" % ' '.join(sys.argv))

    parser = argparse.ArgumentParser()
    ## Required parameters
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--data_dir",
        default="data/",
        type=str,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument(
        "--output_dir",
        default="checkpoints/predictor/",
        type=str,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument(
        "--load_dir",
        type=str,
        help=
        "The output directory where the model checkpoints will be loaded during evaluation"
    )
    parser.add_argument('--load_step',
                        type=int,
                        default=0,
                        help="The checkpoint step to be loaded")
    parser.add_argument("--fact",
                        default="first",
                        choices=["first", "second"],
                        type=str,
                        help="Whether to put fact in front.")
    parser.add_argument(
        "--test_set",
        default="dev",
        choices=["dev", "test", "simple_test", "complex_test", "small_test"],
        help="Which test set is used for evaluation",
        type=str)
    parser.add_argument("--train_batch_size",
                        default=18,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=18,
                        type=int,
                        help="Total batch size for eval.")
    ## Other parameters
    parser.add_argument(
        "--bert_model",
        default="bert-base-uncased",
        type=str,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
        "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument("--task_name",
                        default="QQP",
                        type=str,
                        help="The name of the task to train.")
    parser.add_argument('--period', type=int, default=500,
                        help="evaluate (and possibly checkpoint) every 'period' update steps")
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=256,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=20.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    args = parser.parse_args()
    pprint(vars(args))
    sys.stdout.flush()

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    processors = {
        "qqp": QqpProcessor,
    }

    output_modes = {
        "qqp": "classification",
    }

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')

    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)

    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    logger.info(
        "Datasets are loaded from {}\n Outputs will be saved to {}".format(
            args.data_dir, args.output_dir))
    if os.path.exists(args.output_dir) and os.listdir(
            args.output_dir) and args.do_train:
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    output_mode = output_modes[task_name]

    label_list = processor.get_labels()
    num_labels = len(label_list)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    train_examples = None
    num_train_optimization_steps = None
    if args.do_train:
        train_examples = processor.get_train_examples(args.data_dir)
        num_train_optimization_steps = int(
            len(train_examples) / args.train_batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = (num_train_optimization_steps //
                                            torch.distributed.get_world_size())

    cache_dir = args.cache_dir if args.cache_dir else os.path.join(
        str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(
            args.local_rank))
    if args.load_dir:
        load_dir = args.load_dir
    else:
        load_dir = args.bert_model

    model = BertForSequenceClassification.from_pretrained(
        load_dir, cache_dir=cache_dir, num_labels=num_labels)

    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
    if args.do_train:
        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.01
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.0
        }]
        if args.fp16:
            try:
                from apex.optimizers import FP16_Optimizer
                from apex.optimizers import FusedAdam
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
                )

            optimizer = FusedAdam(optimizer_grouped_parameters,
                                  lr=args.learning_rate,
                                  bias_correction=False,
                                  max_grad_norm=1.0)
            if args.loss_scale == 0:
                optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
            else:
                optimizer = FP16_Optimizer(optimizer,
                                           static_loss_scale=args.loss_scale)
            warmup_linear = WarmupLinearSchedule(
                warmup=args.warmup_proportion,
                t_total=num_train_optimization_steps)

        else:
            optimizer = BertAdam(optimizer_grouped_parameters,
                                 lr=args.learning_rate,
                                 warmup=args.warmup_proportion,
                                 t_total=num_train_optimization_steps)

    global_step = 0
    tr_loss = 0
    best_F1 = 0
    if args.do_train:
        train_features = convert_examples_to_features(train_examples,
                                                      label_list,
                                                      args.max_seq_length,
                                                      tokenizer, output_mode)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                       dtype=torch.long)

        all_label_ids = torch.tensor([f.label_id for f in train_features],
                                     dtype=torch.float)

        train_data = TensorDataset(all_input_ids, all_input_mask,
                                   all_segment_ids, all_label_ids)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        model.train()
        for epoch in range(int(args.num_train_epochs)):
            logger.info("Training epoch {} ...".format(epoch))
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(train_dataloader):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch
                # define a new function to compute loss values for both output_modes
                logits = model(input_ids, segment_ids, input_mask, labels=None)

                loss_fct = BCEWithLogitsLoss()

                loss = loss_fct(logits.view(-1, 1), label_ids.view(-1, 1))

                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()

                tr_loss += loss.item()

                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        # if args.fp16 is False, BertAdam is used that handles this automatically
                        lr_this_step = args.learning_rate * warmup_linear.get_lr(
                            global_step, args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    model.zero_grad()
                    global_step += 1

                if (step + 1) % args.period == 0:
                    # Save a trained model, configuration and tokenizer
                    model_to_save = model.module if hasattr(
                        model, 'module') else model

                    # If we save using the predefined names, we can load using `from_pretrained`
                    model.eval()
                    torch.set_grad_enabled(False)  # turn off gradient tracking
                    F1 = evaluate(args, model, device, processor, label_list,
                                  num_labels, tokenizer, output_mode)

                    if F1 > best_F1:
                        output_dir = os.path.join(
                            args.output_dir,
                            'save_step_{}'.format(global_step))
                        if not os.path.exists(output_dir):
                            os.makedirs(output_dir)

                        output_model_file = os.path.join(
                            output_dir, WEIGHTS_NAME)
                        output_config_file = os.path.join(
                            output_dir, CONFIG_NAME)

                        torch.save(model_to_save.state_dict(),
                                   output_model_file)
                        model_to_save.config.to_json_file(output_config_file)
                        tokenizer.save_vocabulary(output_dir)

                        best_F1 = F1

                    model.train()  # turn on train mode
                    torch.set_grad_enabled(True)  # start gradient tracking
                    tr_loss = 0

    # do eval before exit
    if args.do_eval:
        if not args.do_train:
            global_step = 0
            output_dir = None
        save_dir = output_dir if output_dir is not None else args.load_dir
        load_step = args.load_step

        if args.load_dir is not None:
            load_step = int(
                os.path.split(args.load_dir)[1].replace('save_step_', ''))
            print("load_step = {}".format(load_step))

        F1 = evaluate(args, model, device, processor, label_list, num_labels,
                      tokenizer, output_mode)

        with open("test_result.txt", 'a') as f:
            print("load step: {} F1: {}".format(str(load_step), str(F1)),
                  file=f)
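
Example #23 feeds raw logits and float labels into BCEWithLogitsLoss, which fuses the sigmoid with binary cross-entropy for numerical stability, so no sigmoid is applied in the model's forward pass. A minimal illustration of the shapes involved:

import torch
from torch.nn import BCEWithLogitsLoss

logits = torch.randn(4, 1)                           # raw scores, no sigmoid
labels = torch.tensor([[1.0], [0.0], [1.0], [0.0]])  # float targets
loss = BCEWithLogitsLoss()(logits, labels)
probs = torch.sigmoid(logits)  # apply sigmoid only at prediction time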
Example #24
class Seq2SeqTrainer:
    """
    Seq2SeqTrainer
    """
    def __init__(self,
                 model,
                 criterion,
                 opt_config,
                 print_freq=10,
                 save_freq=1000,
                 grad_clip=float('inf'),
                 batch_first=False,
                 save_info={},
                 save_path='.',
                 train_iterations=0,
                 checkpoint_filename='checkpoint%s.pth',
                 keep_checkpoints=5,
                 math='fp32',
                 loss_scaling={},
                 cuda=True,
                 distributed=False,
                 distributed_overlap_allreduce=False,
                 distributed_overlap_num_allreduce_streams=1,
                 distributed_overlap_allreduce_messagesize=1e7,
                 distributed_overlap_allreduce_communicators=None,
                 intra_epoch_eval=0,
                 prealloc_mode='always',
                 iter_size=1,
                 verbose=False,
                 args=None):
        """
        Constructor for the Seq2SeqTrainer.

        :param model: model to train
        :param criterion: criterion (loss function)
        :param opt_config: dictionary with options for the optimizer
        :param print_freq: prints short summary every 'print_freq' iterations
        :param save_freq: saves checkpoint every 'save_freq' iterations
        :param grad_clip: coefficient for gradient clipping
        :param batch_first: if True the model uses (batch, seq, feature)
            tensors, if False the model uses (seq, batch, feature)
        :param save_info: dict with additional state stored in each checkpoint
        :param save_path: path to the directory for checkpoints
        :param train_iterations: total number of training iterations to execute
        :param checkpoint_filename: name of files with checkpoints
        :param keep_checkpoints: max number of checkpoints to keep
        :param math: arithmetic type
        :param loss_scaling: options for dynamic loss scaling
        :param cuda: if True use cuda, if False train on cpu
        :param distributed: if True run distributed training
        :param intra_epoch_eval: number of additional eval runs within each
            training epoch
        :param prealloc_mode: controls preallocation,
            choices=['off', 'once', 'always']
        :param iter_size: number of iterations between weight updates
        :param verbose: enables verbose logging
        """
        super(Seq2SeqTrainer, self).__init__()
        self.model = model
        self.criterion = criterion
        self.epoch = 0
        self.save_info = save_info
        self.save_path = save_path
        self.save_freq = save_freq
        self.save_counter = 0
        self.checkpoint_filename = checkpoint_filename
        self.checkpoint_counter = cycle(range(keep_checkpoints))
        self.opt_config = opt_config
        self.cuda = cuda
        self.distributed = distributed
        self.print_freq = print_freq
        self.batch_first = batch_first
        self.verbose = verbose
        self.loss = None
        self.translator = None
        self.scheduler = None
        self.intra_epoch_eval = intra_epoch_eval
        self.iter_size = iter_size
        self.prealloc_mode = prealloc_mode
        self.preallocated = False
        
        # Use multi-tensor apply when running FusedAdam under APEX DDP
        # without gradient accumulation
        self.args = args
        self.use_mt = (distributed and iter_size == 1 and
                       opt_config['optimizer'] == 'FusedAdam')

        # Use APEX gradient averaging only when gradient accumulation is enabled
        self.retain_allreduce_buffers = (iter_size == 1)
        self.gradient_average = (iter_size != 1)

        if cuda:
            self.model = self.model.cuda()
            self.criterion = self.criterion.cuda()

        params = self.model.parameters()
        if math == 'fp16':
            self.model = self.model.half()
            if distributed and self.args.distributed_weight_update != 2:
                self.model = DDP(self.model,
                                 message_size=distributed_overlap_allreduce_messagesize,
                                 delay_allreduce=(not distributed_overlap_allreduce),
                                 num_allreduce_streams=distributed_overlap_num_allreduce_streams,
                                 allreduce_communicators=distributed_overlap_allreduce_communicators,
                                 retain_allreduce_buffers=self.retain_allreduce_buffers,
                                 gradient_average=self.gradient_average)

            if self.args.distributed_weight_update == 2:
                # gradient clipping maintained by DistributedFusedAdam
                self.fp_optimizer = DwuFp16Optimizer(
                    self.model,
                    loss_scale=loss_scaling['init_scale'],
                    dls_upscale_interval=loss_scaling['upscale_interval']
                    )
                params = list(self.model.parameters())
            else:
                self.fp_optimizer = Fp16Optimizer(
                    self.model, grad_clip,
                    use_mt=self.use_mt,
                    loss_scale=loss_scaling['init_scale'],
                    dls_upscale_interval=loss_scaling['upscale_interval']
                    )
                params = self.fp_optimizer.fp32_params if isinstance(self.fp_optimizer.fp32_params, list) \
                    else [self.fp_optimizer.fp32_params]
        elif math == 'fp32':
            if distributed:
                self.model = DDP(self.model,
                                 message_size=distributed_overlap_allreduce_messagesize,
                                 delay_allreduce=(not distributed_overlap_allreduce))
            self.fp_optimizer = Fp32Optimizer(self.model, grad_clip)
            # params already points at self.model.parameters(), set above

        opt_name = opt_config.pop('optimizer')
        if opt_name == 'FusedAdam':
            if math == 'fp16' or math == 'fp32':
                if self.args.distributed_weight_update == 2:
                    dwu_args = self.distributed_weight_update_config
                    self.optimizer = DistributedFusedAdam(params, max_grad_norm=grad_clip,
                                                          **dwu_args, **opt_config)
                    self.optimizer.set_global_scale(1.0) # used for grad norm clipping in step function
                else:
                    # Maintain grad norm and scaling by ourselves
                    self.optimizer = FusedAdam(params, use_mt=self.use_mt, **opt_config)
            else:
                self.optimizer = FusedAdam(params, use_mt=self.use_mt, max_grad_norm=grad_clip,
                                           amp_scale_adjustment=get_world_size(), **opt_config)
        else:
            self.optimizer = torch.optim.__dict__[opt_name](params,
                                                            **opt_config)
        logging.info(f'Using optimizer: {self.optimizer}')

        log_event(key=constants.OPT_NAME,
                  value=constants.ADAM, sync=False)
        log_event(key=constants.OPT_BASE_LR,
                  value=opt_config['lr'], sync=False)
        log_event(key=constants.OPT_ADAM_BETA_1,
                  value=self.optimizer.defaults['betas'][0], sync=False)
        log_event(key=constants.OPT_ADAM_BETA_2,
                  value=self.optimizer.defaults['betas'][1], sync=False)
        log_event(key=constants.OPT_ADAM_EPSILON,
                  value=self.optimizer.defaults['eps'], sync=False)

    @property
    def distributed_weight_update_config(self):
        """
        Return a kwarg dictionary that provides arguments for the distributed
        weight update feature.
        """
        return {
            'dwu_group_size': self.args.dwu_group_size,
            'dwu_num_blocks': self.args.dwu_num_blocks,
            'dwu_num_chunks': self.args.dwu_num_chunks,
            'dwu_num_rs_pg': self.args.dwu_num_rs_pg,
            'dwu_num_ar_pg': self.args.dwu_num_ar_pg,
            'dwu_num_ag_pg': self.args.dwu_num_ag_pg,
            'overlap_reductions': self.args.dwu_overlap_reductions,
            'full_pipeline': self.args.dwu_full_pipeline,
            'compute_L2_grad_norm': self.args.dwu_grad_norm,
            'e5m2_allgather': self.args.dwu_e5m2_allgather,
            'predivide': False,
            'flat_mt': True,
        }

    def iterate(self, src, tgt, update=True, training=True):
        """
        Performs one iteration of the training/validation.

        :param src: batch of examples from the source language
        :param tgt: batch of examples from the target language
        :param update: if True: the optimizer updates the model weights
        :param training: if True: runs the backward pass (the weight update
            itself is gated by update)
        """
        src, src_length = src
        tgt, tgt_length = tgt
        src_length = torch.LongTensor(src_length)
        tgt_length = torch.LongTensor(tgt_length)

        num_toks = {}
        num_toks['tgt'] = int(sum(tgt_length - 1))
        num_toks['src'] = int(sum(src_length))

        if self.cuda:
            src = src.cuda(non_blocking=True)
            tgt = tgt.cuda(non_blocking=True)

        if self.batch_first:
            output = self.model(src, src_length, tgt[:, :-1])
            tgt_labels = tgt[:, 1:]
            T, B = output.size(1), output.size(0)
        else:
            output = self.model(src, src_length, tgt[:-1])
            tgt_labels = tgt[1:]
            T, B = output.size(0), output.size(1)

        loss = self.criterion(output.view(T * B, -1),
                              tgt_labels.contiguous().view(-1))

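        # stash the un-normalized loss in a pinned CPU buffer with an async copy,
        # so reading it after the backward pass does not stall the GPU here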
        loss_per_batch = torch.empty((1), dtype=torch.float, device='cpu',
            requires_grad=False, pin_memory=True)
        loss_per_batch.copy_(loss, non_blocking=True)
        loss /= (B * self.iter_size)

        if training:
            self.fp_optimizer.step(loss, self.optimizer, self.scheduler,
                                   update)

        loss_per_batch = loss_per_batch.item()
        loss_per_token = loss_per_batch / num_toks['tgt']
        loss_per_sentence = loss_per_batch / B

        return loss_per_token, loss_per_sentence, num_toks

    def feed_data(self, data_loader, training=True):
        """
        Runs training or validation on batches from data_loader.

        :param data_loader: data loader
        :param training: if True runs training else runs validation
        """
        if training:
            assert self.optimizer is not None
            eval_fractions = np.linspace(0, 1, self.intra_epoch_eval+2)[1:-1]
            iters_with_update = len(data_loader) // self.iter_size
            eval_iters = (eval_fractions * iters_with_update).astype(int)
            eval_iters = eval_iters * self.iter_size
            eval_iters = set(eval_iters)
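            # example: intra_epoch_eval=2, iter_size=1 and 300 weight updates
            # trigger the intra-epoch eval runs at iterations 100 and 200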

        batch_time = AverageMeter(skip_first=False)
        data_time = AverageMeter(skip_first=False)
        losses_per_token = AverageMeter(skip_first=False)
        losses_per_sentence = AverageMeter(skip_first=False)

        tot_tok_time = AverageMeter(skip_first=False)
        src_tok_time = AverageMeter(skip_first=False)
        tgt_tok_time = AverageMeter(skip_first=False)

        batch_size = data_loader.batch_size

        end = time.time()
        for i, (src, tgt) in enumerate(data_loader):
            self.save_counter += 1
            # measure data loading time
            data_time.update(time.time() - end)

            update = False
            if i % self.iter_size == self.iter_size - 1:
                update = True

            # do a train/evaluate iteration
            stats = self.iterate(src, tgt, update, training=training)
            loss_per_token, loss_per_sentence, num_toks = stats

            # measure accuracy and record loss
            losses_per_token.update(loss_per_token, num_toks['tgt'])
            losses_per_sentence.update(loss_per_sentence, batch_size)

            # measure elapsed time
            elapsed = time.time() - end
            batch_time.update(elapsed)
            src_tok_time.update(num_toks['src'] / elapsed)
            tgt_tok_time.update(num_toks['tgt'] / elapsed)
            tot_num_toks = num_toks['tgt'] + num_toks['src']
            tot_tok_time.update(tot_num_toks / elapsed)
            self.loss = losses_per_token.avg

            if training and i in eval_iters:
                assert self.translator is not None
                test_bleu, _ = self.translator.run(calc_bleu=True,
                                                   epoch=self.epoch,
                                                   iteration=i)

                log = []
                log += [f'TRAIN [{self.epoch}][{i}/{len(data_loader)}]']
                log += [f'BLEU: {test_bleu:.2f}']
                log = '\t'.join(log)
                logging.info(log)

                self.model.train()
                self.preallocate(data_loader.batch_size,
                                 data_loader.dataset.max_len, training=True)

            if i % self.print_freq == 0:
                phase = 'TRAIN' if training else 'VALIDATION'
                log = []
                log += [f'{phase} [{self.epoch}][{i}/{len(data_loader)}]']
                log += [f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})']
                log += [f'Data {data_time.val:.2e} ({data_time.avg:.2e})']
                log += [f'Tok/s {tot_tok_time.val:.0f} ({tot_tok_time.avg:.0f})']
                if self.verbose:
                    log += [f'Src tok/s {src_tok_time.val:.0f} ({src_tok_time.avg:.0f})']
                    log += [f'Tgt tok/s {tgt_tok_time.val:.0f} ({tgt_tok_time.avg:.0f})']
                    log += [f'Loss/sentence {losses_per_sentence.val:.1f} ({losses_per_sentence.avg:.1f})']
                log += [f'Loss/tok {losses_per_token.val:.4f} ({losses_per_token.avg:.4f})']
                if training:
                    lr = self.optimizer.param_groups[0]['lr']
                    log += [f'LR {lr:.3e}']
                log = '\t'.join(log)
                logging.info(log)

            save_chkpt = (self.save_counter % self.save_freq) == (self.save_freq - 1)
            if training and save_chkpt:
                self.save_counter = 0
                self.save_info['iteration'] = i
                identifier = next(self.checkpoint_counter, -1)
                if identifier != -1:
                    with sync_workers() as rank:
                        if rank == 0:
                            self.save(identifier=identifier)

            end = time.time()

        tot_tok_time.reduce('sum')
        losses_per_token.reduce('mean')

        return losses_per_token.avg, tot_tok_time.avg

    def preallocate(self, batch_size, max_length, training):
        """
        Generates maximum sequence length batch and runs forward and backward
        pass without updating model parameters.

        :param batch_size: batch size for preallocation
        :param max_length: max sequence length for preallocation
        :param training: if True preallocates memory for backward pass
        """
        if self.prealloc_mode == 'always' or (self.prealloc_mode == 'once' and
                                              not self.preallocated):
            logging.info('Executing preallocation')
            torch.cuda.empty_cache()

            src_length = [max_length] * batch_size
            tgt_length = [max_length] * batch_size

            if self.batch_first:
                shape = (batch_size, max_length)
            else:
                shape = (max_length, batch_size)

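            # build a dummy max-size batch; 4 is assumed to be a valid token id,
            # its exact value does not matter for the warm-up forward/backward pass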
            src = torch.full(shape, 4, dtype=torch.int64)
            tgt = torch.full(shape, 4, dtype=torch.int64)
            src = src, src_length
            tgt = tgt, tgt_length
            self.iterate(src, tgt, update=False, training=training)
            self.model.zero_grad()
            self.preallocated = True

    def optimize(self, data_loader):
        """
        Sets model in training mode, preallocates memory and runs training on
        data provided by data_loader.

        :param data_loader: data loader
        """
        torch.set_grad_enabled(True)
        self.model.train()
        self.preallocate(data_loader.batch_size, data_loader.dataset.max_len,
                         training=True)

        output = self.feed_data(data_loader, training=True)

        self.model.zero_grad()
        return output

    def evaluate(self, data_loader):
        """
        Sets model in eval mode, disables gradients, preallocates memory and
        runs validation on data provided by data_loader.

        :param data_loader: data loader
        """
        torch.set_grad_enabled(False)
        self.model.eval()
        self.preallocate(data_loader.batch_size, data_loader.dataset.max_len,
                         training=False)

        output = self.feed_data(data_loader, training=False)

        self.model.zero_grad()
        return output

    def load(self, filename):
        """
        Loads checkpoint from filename.

        :param filename: path to the checkpoint file
        """
        if os.path.isfile(filename):
            checkpoint = torch.load(filename, map_location={'cuda:0': 'cpu'})
            if self.distributed:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            self.fp_optimizer.initialize_model(self.model)
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            assert self.scheduler is not None
            self.scheduler.load_state_dict(checkpoint['scheduler'])
            self.epoch = checkpoint['epoch']
            self.loss = checkpoint['loss']
            logging.info(f'Loaded checkpoint {filename} (epoch {self.epoch})')
        else:
            logging.error(f'Invalid checkpoint: {filename}')

    def save(self, identifier=None, is_best=False, save_all=False):
        """
        Stores checkpoint to a file.

        :param identifier: identifier for periodic checkpoint
        :param is_best: if True stores checkpoint to 'model_best.pth'
        :param save_all: if True stores checkpoint after completed training
            epoch
        """

        def write_checkpoint(state, filename):
            filename = os.path.join(self.save_path, filename)
            logging.info(f'Saving model to {filename}')
            torch.save(state, filename)

        if self.distributed:
            model_state = self.model.module.state_dict()
        else:
            model_state = self.model.state_dict()

        assert self.scheduler is not None
        state = {
            'epoch': self.epoch,
            'state_dict': model_state,
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'loss': getattr(self, 'loss', None),
        }
        state = dict(list(state.items()) + list(self.save_info.items()))

        if identifier is not None:
            filename = self.checkpoint_filename % identifier
            write_checkpoint(state, filename)

        if is_best:
            filename = 'model_best.pth'
            write_checkpoint(state, filename)

        if save_all:
            filename = f'checkpoint_epoch_{self.epoch:03d}.pth'
            write_checkpoint(state, filename)
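
A minimal, self-contained sketch (not part of the original listing) of the checkpoint rotation driven by checkpoint_counter and checkpoint_filename above: with keep_checkpoints=3 the same three filenames are reused in turn, so periodic checkpoints are overwritten instead of piling up.

from itertools import cycle

checkpoint_counter = cycle(range(3))     # mirrors keep_checkpoints=3
filename_template = 'checkpoint%s.pth'   # mirrors checkpoint_filename
for step in range(7):
    identifier = next(checkpoint_counter)
    print(filename_template % identifier)
# checkpoint0.pth, checkpoint1.pth, checkpoint2.pth, checkpoint0.pth, ...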
Example #25
def main():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--bert_model",
        default="bert-base-uncased",
        type=str,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.",
    )
    parser.add_argument(
        "--from_pretrained",
        default="bert-base-uncased",
        type=str,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.",
    )
    parser.add_argument(
        "--output_dir",
        default="save",
        type=str,
        help="The output directory where the model checkpoints will be written.",
    )
    parser.add_argument(
        "--config_file",
        default="config/bert_base_6layer_6conect.json",
        type=str,
        help="The config file which specified the model details.",
    )
    parser.add_argument(
        "--num_train_epochs",
        default=20,
        type=int,
        help="Total number of training epochs to perform.",
    )
    parser.add_argument(
        "--train_iter_multiplier",
        default=1.0,
        type=float,
        help="multiplier for the multi-task training.",
    )
    parser.add_argument(
        "--train_iter_gap",
        default=4,
        type=int,
        help="forward every n iteration is the validation score is not improving over the last 3 epoch, -1 means will stop",
    )
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help="Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.",
    )
    parser.add_argument(
        "--no_cuda", action="store_true", help="Whether not to use CUDA when available"
    )
    parser.add_argument(
        "--do_lower_case",
        default=True,
        type=bool,
        help="Whether to lower case the input text. True for uncased models, False for cased models.",
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="local_rank for distributed training on gpus",
    )
    parser.add_argument(
        "--seed", type=int, default=0, help="random seed for initialization"
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumualte before performing a backward/update pass.",
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit float precision instead of 32-bit",
    )
    parser.add_argument(
        "--loss_scale",
        type=float,
        default=0,
        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=16,
        help="Number of workers in the dataloader.",
    )
    parser.add_argument(
        "--save_name", default="", type=str, help="save name for training."
    )
    parser.add_argument(
        "--in_memory",
        default=False,
        type=bool,
        help="whether use chunck for parallel training.",
    )
    parser.add_argument(
        "--optim", default="AdamW", type=str, help="what to use for the optimization."
    )
    parser.add_argument(
        "--freeze",
        default=-1,
        type=int,
        help="till which layer of textual stream of vilbert need to fixed.",
    )
    parser.add_argument(
        "--vision_scratch",
        action="store_true",
        help="whether pre-trained the image or not.",
    )
    parser.add_argument(
        "--evaluation_interval", default=1, type=int, help="evaluate very n epoch."
    )
    parser.add_argument(
        "--lr_scheduler",
        default="mannul",
        type=str,
        help="whether use learning rate scheduler.",
    )
    parser.add_argument(
        "--baseline", action="store_true", help="whether use single stream baseline."
    )
    parser.add_argument(
        "--resume_file", default="", type=str, help="Resume from checkpoint"
    )
    parser.add_argument(
        "--dynamic_attention",
        action="store_true",
        help="whether use dynamic attention.",
    )
    parser.add_argument(
        "--clean_train_sets",
        default=True,
        type=bool,
        help="whether clean train sets for multitask data.",
    )
    parser.add_argument(
        "--visual_target",
        default=0,
        type=int,
        help="which target to use for visual branch. \
        0: soft label, \
        1: regress the feature, \
        2: NCE loss.",
    )
    args = parser.parse_args()
    with open("task_config.yml", "r") as f:
        task_cfg = edict(yaml.safe_load(f))

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if args.baseline:
        from pytorch_transformers.modeling_bert import BertConfig
        from src.models.basebert import BaseBertForVLTasks
    else:
        from src.models.vilbert import BertConfig
        from src.models.vilbert import VILBertForVLTasks

    name = task_cfg["name"]
    task_lr = task_cfg["lr"]

    base_lr = task_lr
    loss_scale = task_lr / base_lr

    if args.save_name:
        prefix = "-" + args.save_name
    else:
        prefix = ""
    timeStamp = (
        args.config_file.split("/")[1].split(".")[0]
        + prefix
    )
    savePath = os.path.join(args.output_dir, timeStamp)

    bert_weight_name = json.load(
        open("config/" + args.bert_model + "_weight_name.json", "r")
    )

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device(
            "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
        )
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        torch.distributed.init_process_group(backend="nccl")

    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
            device, n_gpu, bool(args.local_rank != -1), args.fp16
        )
    )

    default_gpu = False
    if dist.is_available() and args.local_rank != -1:
        rank = dist.get_rank()
        if rank == 0:
            default_gpu = True
    else:
        default_gpu = True

    if default_gpu:
        if not os.path.exists(savePath):
            os.makedirs(savePath)

    config = BertConfig.from_json_file(args.config_file)
    if default_gpu:
        # save all the hidden parameters.
        with open(os.path.join(savePath, "command.txt"), "w") as f:
            print(args, file=f)  # Python 3.x
            print("\n", file=f)
            print(config, file=f)

    # load dataset
    task_batch_size, task_num_iters, task_datasets_train, task_datasets_val, task_dataloader_train, task_dataloader_val = LoadDatasets(
        args, task_cfg
    )

    logdir = os.path.join(savePath, "logs")
    tbLogger = utils.tbLogger(
        logdir,
        savePath,
        task_num_iters,
        args.gradient_accumulation_steps,
    )

    if args.visual_target == 0:
        config.v_target_size = 1601
        config.visual_target = args.visual_target
    else:
        config.v_target_size = 2048
        config.visual_target = args.visual_target

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    task_ave_iter = int(
        task_cfg["num_epoch"]
        * task_num_iters
        * args.train_iter_multiplier
        / args.num_train_epochs
    )
    task_stop_controller = utils.TaskStopOnPlateau(
        mode="max",
        patience=1,
        continue_threshold=0.005,
        cooldown=1,
        threshold=0.001,
    )

    median_num_iter = task_ave_iter
    num_train_optimization_steps = (
        median_num_iter * args.num_train_epochs // args.gradient_accumulation_steps
    )
    num_labels = task_datasets_train.num_labels

    if args.dynamic_attention:
        config.dynamic_attention = True

    model = VILBertForVLTasks.from_pretrained(
        args.from_pretrained,
        config=config,
        num_labels=num_labels,
        default_gpu=default_gpu,
    )

    task_losses = LoadLosses(args, task_cfg)

    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]

    if args.freeze != -1:
        bert_weight_name_filtered = []
        for name in bert_weight_name:
            if "embeddings" in name:
                bert_weight_name_filtered.append(name)
            elif "encoder" in name:
                layer_num = name.split(".")[2]
                if int(layer_num) <= args.freeze:
                    bert_weight_name_filtered.append(name)

        optimizer_grouped_parameters = []
        for key, value in dict(model.named_parameters()).items():
            if key[12:] in bert_weight_name_filtered:
                value.requires_grad = False

        if default_gpu:
            print("filtered weight")
            print(bert_weight_name_filtered)

    optimizer_grouped_parameters = []
    for key, value in dict(model.named_parameters()).items():
        if value.requires_grad:
            if "vil_" in key:
                lr = 1e-4
            else:
                if args.vision_scratch:
                    if key[12:] in bert_weight_name:
                        lr = base_lr
                    else:
                        lr = 1e-4
                else:
                    lr = base_lr
            if any(nd in key for nd in no_decay):
                optimizer_grouped_parameters += [
                    {"params": [value], "lr": lr, "weight_decay": 0.0}
                ]
            if not any(nd in key for nd in no_decay):
                optimizer_grouped_parameters += [
                    {"params": [value], "lr": lr, "weight_decay": 0.01}
                ]
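    # each trainable parameter gets its own group, so per-parameter learning
    # rates and weight-decay values can coexist inside a single optimizer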

    if default_gpu:
        print(len(list(model.named_parameters())), len(optimizer_grouped_parameters))

    # choose optimizer
    if args.optim == "AdamW":
        optimizer = AdamW(optimizer_grouped_parameters, lr=base_lr, correct_bias=False)
    elif args.optim == "RAdam":
        optimizer = RAdam(optimizer_grouped_parameters, lr=base_lr)

    # choose scheduler
    warmup_steps = args.warmup_proportion * num_train_optimization_steps

    if args.lr_scheduler == "warmup_linear":
        warmup_scheduler = WarmupLinearSchedule(
            optimizer, warmup_steps=warmup_steps, t_total=num_train_optimization_steps
        )
    else:
        warmup_scheduler = WarmupConstantSchedule(optimizer, warmup_steps=warmup_steps)

    lr_reduce_list = np.array([5, 7])
    if args.lr_scheduler == "automatic":
        lr_scheduler = ReduceLROnPlateau(
            optimizer, mode="max", factor=0.2, patience=1, cooldown=1, threshold=0.001
        )
    elif args.lr_scheduler == "cosine":
        lr_scheduler = CosineAnnealingLR(
            optimizer, T_max=median_num_iter * args.num_train_epochs
        )
    elif args.lr_scheduler == "cosine_warm":
        lr_scheduler = CosineAnnealingWarmRestarts(
            optimizer, T_0=median_num_iter * args.num_train_epochs
        )
    elif args.lr_scheduler == "mannul":

        def lr_lambda_fun(epoch):
            return pow(0.2, np.sum(lr_reduce_list <= epoch))
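        # with lr_reduce_list = [5, 7]: factor 1.0 for epochs 0-4,
        # 0.2 for epochs 5-6, and 0.04 from epoch 7 onwards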

        lr_scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda_fun)

    startIterID = 0
    global_step = 0
    start_epoch = 0

    if args.resume_file != "" and os.path.exists(args.resume_file):
        checkpoint = torch.load(args.resume_file, map_location="cpu")
        new_dict = {}
        for attr in checkpoint["model_state_dict"]:
            if attr.startswith("module."):
                new_dict[attr.replace("module.", "", 1)] = checkpoint[
                    "model_state_dict"
                ][attr]
            else:
                new_dict[attr] = checkpoint["model_state_dict"][attr]
        model.load_state_dict(new_dict)
        warmup_scheduler.load_state_dict(checkpoint["warmup_scheduler_state_dict"])
        # lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        global_step = checkpoint["global_step"]
        start_epoch = int(checkpoint["epoch_id"]) + 1
        task_stop_controller = checkpoint["task_stop_controller"]
        tbLogger = checkpoint["tb_logger"]
        del checkpoint

    model.to(device)

    print("`==============`MODEL=============")
    print(next(model.parameters()).is_cuda)#False

    for state in optimizer.state.values():
        for k, v in state.items():
            if torch.is_tensor(v):
                state[k] = v.cuda()

    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )
        model = DDP(model, delay_allreduce=True)

    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    if default_gpu:
        print("***** Running training *****")
        print("  Num Iters: ", task_num_iters)
        print("  Batch size: ", task_batch_size)
        print("  Num steps: %d" % num_train_optimization_steps)

    task_iter_train = None
    task_count = 0
    for epochId in tqdm(range(start_epoch, args.num_train_epochs), desc="Epoch", ncols=100):
        model.train()
        for step in range(median_num_iter):
            iterId = startIterID + step + (epochId * median_num_iter)
            first_task = True

            is_forward = False
            if (not task_stop_controller.in_stop) or (
                iterId % args.train_iter_gap == 0
            ):
                is_forward = True

            if is_forward:
                loss, score = ForwardModelsTrain(
                    args,
                    task_cfg,
                    device,
                    task_count,
                    task_iter_train,
                    task_dataloader_train,
                    model,
                    task_losses,
                )

                loss = loss * loss_scale
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                loss.backward()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        lr_this_step = args.learning_rate * warmup_linear(
                            global_step / num_train_optimization_steps,
                            args.warmup_proportion,
                        )
                        for param_group in optimizer.param_groups:
                            param_group["lr"] = lr_this_step

                    if first_task and (
                        global_step < warmup_steps
                        or args.lr_scheduler == "warmup_linear"
                    ):
                        warmup_scheduler.step()

                    optimizer.step()
                    model.zero_grad()
                    if first_task:
                        global_step += 1
                        first_task = False

                    if default_gpu:
                        tbLogger.step_train(
                            epochId,
                            iterId,
                            float(loss),
                            float(score),
                            optimizer.param_groups[0]["lr"],
                            "train",
                        )

            if "cosine" in args.lr_scheduler and global_step > warmpu_steps:
                lr_scheduler.step()

            if (
                step % (20 * args.gradient_accumulation_steps) == 0
                and step != 0
                and default_gpu
            ):
                tbLogger.showLossTrain()

            # decide whether to evaluate on SNLI tasks.
            if (iterId != 0 and iterId % task_num_iters == 0) or (
                epochId == args.num_train_epochs - 1 and step == median_num_iter - 1
            ):
                evaluate(
                    args,
                    task_dataloader_val,
                    task_stop_controller,
                    task_cfg,
                    device,
                    model,
                    task_losses,
                    epochId,
                    default_gpu,
                    tbLogger,
                )

        if args.lr_scheduler == "automatic":
            lr_scheduler.step(sum(val_scores.values()))
            logger.info("best average score is %3f" % lr_scheduler.best)
        elif args.lr_scheduler == "mannul":
            lr_scheduler.step()

        if epochId in lr_reduce_list:
            # reset the task_stop_controller once the lr drop
            task_stop_controller._reset()

        if default_gpu:
            # Save a trained model
            logger.info("** ** * Saving fine - tuned model ** ** * ")
            model_to_save = (
                model.module if hasattr(model, "module") else model
            )  # Only save the model it-self
            output_model_file = os.path.join(
                savePath, "pytorch_model_" + str(epochId) + ".bin"
            )
            output_checkpoint = os.path.join(savePath, "pytorch_ckpt_latest.tar")
            torch.save(model_to_save.state_dict(), output_model_file)
            torch.save(
                {
                    "model_state_dict": model_to_save.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "warmup_scheduler_state_dict": warmup_scheduler.state_dict(),
                    # 'lr_scheduler_state_dict': lr_scheduler.state_dict(),
                    "global_step": global_step,
                    "epoch_id": epochId,
                    "task_stop_controller": task_stop_controller,
                    "tb_logger": tbLogger,
                },
                output_checkpoint,
            )
    tbLogger.txt_close()
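
The resume block above has to strip the "module." prefix that DataParallel/DistributedDataParallel prepend to parameter names. A small self-contained sketch of the same idea (strip_module_prefix is an illustrative name, not from the listing):

import torch
import torch.nn as nn

def strip_module_prefix(state_dict):
    # DataParallel/DDP wrap the model, prefixing every key with "module.";
    # dropping the prefix lets the checkpoint load into an unwrapped model
    return {k[len("module."):] if k.startswith("module.") else k: v
            for k, v in state_dict.items()}

model = nn.Linear(4, 2)
wrapped = nn.DataParallel(model)   # keys become "module.weight", "module.bias"
model.load_state_dict(strip_module_prefix(wrapped.state_dict()))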
Example #26
def train(args, model, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_dir = os.path.join("tensorboard", args.model_name)
        os.makedirs(tb_dir, exist_ok=True)
        tb_writer = SummaryWriter(tb_dir)

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm', 'gamma', 'beta']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]
    optimizer = FusedLAMB(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer=optimizer,
                                     warmup_steps=args.warmup_steps,
                                     t_total=args.num_steps)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(models=model,
                                          optimizers=optimizer,
                                          opt_level=args.fp16_opt_level,
                                          cast_model_outputs=torch.float16)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = DDP(
            model,
            message_size=250000000,
            gradient_predivide_factor=torch.distributed.get_world_size())

    train_dataset = LMDataset(corpus_path=args.corpus_path,
                              tokenizer=tokenizer,
                              local_rank=args.local_rank,
                              seq_len=args.max_seq_length,
                              vocab_size=args.vocab_size,
                              mask_prob=args.mask_prob)
    # Train!
    logger.info("***** Running training *****")
    logger.info("  Total optimization steps = %d", args.num_steps)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)

    global_step = 0
    iters = 0
    model.zero_grad()
    model.train()
    set_seed(
        args)  # Added here for reproducibility (even between python 2 and 3)
    while True:
        train_dataset.gen_segment()
        train_sampler = RandomSampler(
            train_dataset) if args.local_rank == -1 else DistributedSampler(
                train_dataset)
        train_dataloader = DataLoader(train_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size,
                                      num_workers=4,
                                      pin_memory=True)
        epoch_iterator = tqdm(
            train_dataloader,
            desc="Training (X iter) (XX / XX Steps) (Total Loss=X.X) "
            "(Generator Loss=X.X) (Discriminator Loss=X.X)",
            disable=args.local_rank not in [-1, 0])
        tr_loss = 0.0
        for step, batch in enumerate(epoch_iterator):
            batch = tuple(t.to(args.device) for t in batch)
            input_ids, input_mask, segment_ids, lm_label_ids = batch
            gen_loss, disc_loss = model(input_ids, segment_ids, input_mask,
                                        lm_label_ids)

            loss = gen_loss + disc_loss
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            mean_loss = tr_loss * args.gradient_accumulation_steps / (step + 1)

            if (step + 1) % args.gradient_accumulation_steps == 0:
                scheduler.step()  # learning rate warmup
                optimizer.step()
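                # setting grads to None instead of calling zero_grad() skips the
                # memset; the next backward pass allocates fresh gradient buffers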
                for param in model.parameters():
                    param.grad = None
                global_step += 1

                epoch_iterator.set_description(
                    "Training (%d iter) (%d / %d Steps) (Mean Loss=%2.5f) (Generator Loss=%2.5f) (Discriminator Loss=%2.5f)"
                    % (iters, global_step, args.num_steps, mean_loss, gen_loss,
                       disc_loss / 50.0))

                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    tb_writer.add_scalar('lr',
                                         scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar('Mean_Loss', mean_loss, global_step)
                    tb_writer.add_scalar('Gen_Loss', gen_loss, global_step)
                    tb_writer.add_scalar('Disc_Loss', disc_loss / 50.0,
                                         global_step)

                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    model_to_save = model.module if hasattr(
                        model, 'module') else model
                    model_checkpoint = os.path.join(
                        args.output_dir,
                        args.model_name + '_' + str(global_step) + '.bin')
                    model_layer_checkpoint = os.path.join(
                        args.output_dir,
                        args.model_name + '_' + str(global_step) + '_disc.bin')
                    torch.save(model_to_save.state_dict(), model_checkpoint)
                    torch.save(model_to_save.discriminator.model.state_dict(),
                               model_layer_checkpoint)
                    logger.info("Saving model checkpoint to %s",
                                args.output_dir)
            if args.num_steps > 0 and global_step == args.num_steps:
                epoch_iterator.close()
                break
        if args.num_steps > 0 and global_step == args.num_steps:
            epoch_iterator.close()
            break
        iters += 1
    if args.local_rank in [-1, 0]:
        model_to_save = model.module if hasattr(model, 'module') else model
        model_checkpoint = os.path.join(
            args.output_dir, args.model_name + '_' + str(global_step) + '.bin')
        model_layer_checkpoint = os.path.join(
            args.output_dir,
            args.model_name + '_' + str(global_step) + '_disc.bin')
        torch.save(model_to_save.state_dict(), model_checkpoint)
        torch.save(model_to_save.discriminator.model.state_dict(),
                   model_layer_checkpoint)
        logger.info("Saving model checkpoint to %s", args.output_dir)
        logger.info("End Training!")
        tb_writer.close()
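
For reference, the gradient-accumulation pattern from the loop above reduced to a minimal pure-PyTorch sketch; the model, data, and accum_steps here are illustrative stand-ins, not values from the listing:

import torch
import torch.nn as nn

model = nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()
accum_steps = 4

for step in range(16):
    x, y = torch.randn(2, 8), torch.randn(2, 1)
    loss = criterion(model(x), y) / accum_steps   # scale so gradients match a full batch
    loss.backward()                               # gradients accumulate across micro-batches
    if (step + 1) % accum_steps == 0:
        optimizer.step()
        for p in model.parameters():              # same grad=None trick as the loop above
            p.grad = None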