Example #1
0
def train(args, train_dataset, model: PreTrainedModel,
          tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:
    """Run the full training loop for a language model under BytePS.

    Args:
        args: Parsed command-line namespace carrying batch sizes, learning
            rate, fp16 flags, logging/checkpoint intervals, etc.
        train_dataset: Dataset of pre-tokenized examples (variable length).
        model: The transformers model to train.
        tokenizer: Tokenizer; used for batch padding and, when ``args.mlm``
            is set, for masked-LM token masking.

    Returns:
        Tuple of ``(global_step, average training loss per optimizer step)``.
    """
    # TensorBoard writer only on the master process.
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    def collate(examples: List[torch.Tensor]):
        # Pad a batch of variable-length sequences to the longest example;
        # fall back to zero padding when the tokenizer has no pad token.
        if tokenizer._pad_token is None:
            return pad_sequence(examples, batch_first=True)
        return pad_sequence(examples,
                            batch_first=True,
                            padding_value=tokenizer.pad_token_id)

    # Shard the dataset across BytePS workers so each sees a disjoint slice.
    train_sampler = DistributedSampler(train_dataset,
                                       num_replicas=bps.size(),
                                       rank=bps.rank())
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size,
                                  collate_fn=collate)

    # Total optimizer steps (t_total): either fixed by --max_steps (epochs
    # derived from it), or computed as epochs * optimizer-steps-per-epoch.
    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay).
    # Biases and LayerNorm weights are conventionally excluded from decay.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]
    optimizer = SGD(optimizer_grouped_parameters,
                    lr=args.learning_rate,
                    momentum=0.9)

    # BytePS: average gradients across workers and sync initial weights
    # from rank 0 so all workers start identical.
    optimizer = bps.DistributedOptimizer(
        optimizer, named_parameters=model.named_parameters())
    bps.broadcast_parameters(model.state_dict(), root_rank=0)

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist alongside the model.
    if (args.model_name_or_path and os.path.isfile(
            os.path.join(args.model_name_or_path, "optimizer.pt"))
            and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler.pt"))):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (bps.size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint: the checkpoint
    # directory name is expected to end in "-<global_step>".
    if args.model_name_or_path and os.path.exists(args.model_name_or_path):
        try:
            # set global_step to global_step of last saved checkpoint from model path
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split(
                "/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // (len(train_dataloader) //
                                             args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (
                len(train_dataloader) // args.gradient_accumulation_steps)

            logger.info(
                "  Continuing training from checkpoint, will skip to saved global_step"
            )
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d",
                        global_step)
            logger.info("  Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
        except ValueError:
            # Directory name does not encode a step count; train from scratch.
            logger.info("  Starting fine-tuning.")

    tr_loss, logging_loss = 0.0, 0.0

    # Resize embeddings in case the tokenizer vocabulary grew.
    model_to_resize = model.module if hasattr(
        model,
        "module") else model  # Take care of distributed/parallel training
    model_to_resize.resize_token_embeddings(len(tokenizer))

    model.zero_grad()
    train_iterator = trange(epochs_trained,
                            int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])
    set_seed(args)  # Added here for reproducibility
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            # MLM: mask a fraction of input tokens and predict them;
            # otherwise do causal LM where labels are the inputs themselves.
            inputs, labels = mask_tokens(batch, tokenizer,
                                         args) if args.mlm else (batch, batch)
            inputs = inputs.to(args.device)
            labels = labels.to(args.device)
            model.train()
            outputs = model(inputs,
                            masked_lm_labels=labels) if args.mlm else model(
                                inputs, labels=labels)
            loss = outputs[
                0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            # Take an optimizer step only every gradient_accumulation_steps
            # micro-batches; gradients accumulate in between.
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if (
                            args.local_rank == -1
                            and args.evaluate_during_training
                    ):  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value,
                                                 global_step)
                    tb_writer.add_scalar("lr",
                                         scheduler.get_lr()[0], global_step)
                    # Report loss averaged over the last logging window.
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)
                    logging_loss = tr_loss

                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    checkpoint_prefix = "checkpoint"
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir,
                        "{}-{}".format(checkpoint_prefix, global_step))
                    os.makedirs(output_dir, exist_ok=True)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    # Prune old checkpoints per args.save_total_limit policy.
                    _rotate_checkpoints(args, checkpoint_prefix)

                    torch.save(optimizer.state_dict(),
                               os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s",
                                output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    # max(global_step, 1) guards against ZeroDivisionError when no optimizer
    # step was ever taken (empty dataloader, or a resume that skipped
    # everything); the ratio is unchanged whenever global_step > 0.
    return global_step, tr_loss / max(global_step, 1)
Example #2
0
optimizer = bps.DistributedOptimizer(
    optimizer,
    named_parameters=model.named_parameters(),
    compression=compression,
    backward_passes_per_step=args.batches_per_allreduce)

# Restore from a previous checkpoint, if initial_epoch is specified.
# BytePS: restore on the first worker which will broadcast weights to other workers.
if resume_from_epoch > 0 and bps.rank() == 0:
    filepath = args.checkpoint_format.format(epoch=resume_from_epoch)
    checkpoint = torch.load(filepath)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])

# BytePS: broadcast parameters & optimizer state.
bps.broadcast_parameters(model.state_dict(), root_rank=0)
bps.broadcast_optimizer_state(optimizer, root_rank=0)


def train(epoch):
    model.train()
    train_sampler.set_epoch(epoch)
    train_loss = Metric('train_loss')
    train_accuracy = Metric('train_accuracy')

    with tqdm(total=len(train_loader),
              desc='Train Epoch     #{}'.format(epoch + 1),
              disable=not verbose) as t:
        for batch_idx, (data, target) in enumerate(train_loader):
            adjust_learning_rate(epoch, batch_idx)
Example #3
0
    def build_model(self):
        """Build datasets, dataloaders, networks, losses, and optimizers.

        Sets up train/test transforms, shards each image split across
        BytePS workers with a DistributedSampler, instantiates the two
        generators and four discriminators, broadcasts their initial
        weights from rank 0, and wraps both Adam optimizers in BytePS
        DistributedOptimizer.
        """

        # Training augmentation: `fix_aug` uses a deterministic resize,
        # otherwise resize-then-random-crop adds spatial jitter.
        if self.fix_aug:
            print("FIX AUG ON")
            train_transform = transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.Resize((self.img_size, self.img_size)),
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
            ])
        else:
            train_transform = transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.Resize((self.img_size + 30, self.img_size + 30)),
                transforms.RandomCrop(self.img_size),
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
            ])

        # Test transform is deterministic: resize + normalize only.
        test_transform = transforms.Compose([
            transforms.Resize((self.img_size, self.img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])

        # Two unpaired domains: trainA/testA vs trainB/testB.
        self.trainA = ImageFolder(os.path.join(self.dataset_dir, self.dataset,
                                               'trainA'),
                                  train_transform,
                                  list_mode=self.list_mode)
        self.trainB = ImageFolder(os.path.join(self.dataset_dir, self.dataset,
                                               'trainB'),
                                  train_transform,
                                  list_mode=self.list_mode)
        self.testA = ImageFolder(os.path.join(self.dataset_dir, self.dataset,
                                              'testA'),
                                 test_transform,
                                 list_mode=self.list_mode)
        self.testB = ImageFolder(os.path.join(self.dataset_dir, self.dataset,
                                              'testB'),
                                 test_transform,
                                 list_mode=self.list_mode)

        # Shard every split across BytePS workers so each rank sees a
        # disjoint subset.
        trainA_sampler = torch.utils.data.distributed.DistributedSampler(
            self.trainA, num_replicas=bps.size(), rank=bps.rank())
        trainB_sampler = torch.utils.data.distributed.DistributedSampler(
            self.trainB, num_replicas=bps.size(), rank=bps.rank())
        testA_sampler = torch.utils.data.distributed.DistributedSampler(
            self.testA, num_replicas=bps.size(), rank=bps.rank())
        testB_sampler = torch.utils.data.distributed.DistributedSampler(
            self.testB, num_replicas=bps.size(), rank=bps.rank())

        self.trainA_loader = DataLoader(self.trainA,
                                        batch_size=self.batch_size,
                                        sampler=trainA_sampler,
                                        num_workers=1)
        self.trainB_loader = DataLoader(self.trainB,
                                        batch_size=self.batch_size,
                                        sampler=trainB_sampler,
                                        num_workers=1)
        self.testA_loader = DataLoader(self.testA,
                                       batch_size=1,
                                       sampler=testA_sampler)
        self.testB_loader = DataLoader(self.testB,
                                       batch_size=1,
                                       sampler=testB_sampler)
        """ Define Generator, Discriminator """
        # Two generators (A->B and B->A); per domain, a 7-layer "global"
        # and a 5-layer "local" discriminator.
        self.genA2B = ResnetGenerator(input_nc=3,
                                      output_nc=3,
                                      ngf=self.ch,
                                      n_blocks=self.n_res,
                                      img_size=self.img_size,
                                      light=self.light).to(self.device)
        self.genB2A = ResnetGenerator(input_nc=3,
                                      output_nc=3,
                                      ngf=self.ch,
                                      n_blocks=self.n_res,
                                      img_size=self.img_size,
                                      light=self.light).to(self.device)
        self.disGA = Discriminator(input_nc=3, ndf=self.ch,
                                   n_layers=7).to(self.device)
        self.disGB = Discriminator(input_nc=3, ndf=self.ch,
                                   n_layers=7).to(self.device)
        self.disLA = Discriminator(input_nc=3, ndf=self.ch,
                                   n_layers=5).to(self.device)
        self.disLB = Discriminator(input_nc=3, ndf=self.ch,
                                   n_layers=5).to(self.device)
        """ Define Loss """
        self.L1_loss = nn.L1Loss().to(self.device)
        self.MSE_loss = nn.MSELoss().to(self.device)
        self.BCE_loss = nn.BCEWithLogitsLoss().to(self.device)

        # Collect (name, param) pairs with model-name prefixes so the two
        # DistributedOptimizers below see globally unique parameter names.
        gen_named_parameters = []
        dis_named_parameters = []
        for n, p in (list(self.genA2B.named_parameters(prefix='genA2B')) +
                     list(self.genB2A.named_parameters(prefix='genB2A'))):
            gen_named_parameters.append((n, p))
        for n, p in (list(self.disGA.named_parameters(prefix='disGA')) +
                     list(self.disGB.named_parameters(prefix='disGB')) +
                     list(self.disLA.named_parameters(prefix='disLA')) +
                     list(self.disLB.named_parameters(prefix='disLB'))):
            dis_named_parameters.append((n, p))

        # Merge the per-model state dicts under the same prefixes so the
        # broadcast below syncs all six networks from rank 0.
        gen_state_dict = OrderedDict(
            [("genA2B." + k, v) for k, v in self.genA2B.state_dict().items()] +
            [("genB2A." + k, v) for k, v in self.genB2A.state_dict().items()])
        dis_state_dict = OrderedDict(
            [("disGA." + k, v) for k, v in self.disGA.state_dict().items()] +
            [("disGB." + k, v) for k, v in self.disGB.state_dict().items()] +
            [("disLA." + k, v) for k, v in self.disLA.state_dict().items()] +
            [("disLB." + k, v) for k, v in self.disLB.state_dict().items()])

        bps.broadcast_parameters(gen_state_dict, root_rank=0)
        bps.broadcast_parameters(dis_state_dict, root_rank=0)
        """ Trainer """
        self.G_optim = torch.optim.Adam(itertools.chain(
            self.genA2B.parameters(), self.genB2A.parameters()),
                                        lr=self.lr,
                                        betas=(0.5, 0.999),
                                        weight_decay=self.weight_decay)
        self.D_optim = torch.optim.Adam(itertools.chain(
            self.disGA.parameters(), self.disGB.parameters(),
            self.disLA.parameters(), self.disLB.parameters()),
                                        lr=self.lr,
                                        betas=(0.5, 0.999),
                                        weight_decay=self.weight_decay)

        # NOTE(review): `named_parameters` is never read in the visible part
        # of this method (the optimizers use gen_/dis_named_parameters) —
        # looks like dead code, but verify nothing after this method's
        # visible end consumes it before removing.
        named_parameters = []
        for n, p in list(self.genA2B.named_parameters()):
            named_parameters.append(("genA2B." + n, p))
        for n, p in list(self.genB2A.named_parameters()):
            named_parameters.append(("genB2A." + n, p))

        self.G_optim = bps.DistributedOptimizer(
            self.G_optim,
            named_parameters=gen_named_parameters,
            compression=bps.Compression.none)

        self.D_optim = bps.DistributedOptimizer(
            self.D_optim,
            named_parameters=dis_named_parameters,
            compression=bps.Compression.none)

        # NOTE(review): clearing `_handles` touches BytePS's private
        # optimizer internals — presumably to drop pending all-reduce
        # handles registered at wrap time; confirm against the pinned
        # BytePS version before upgrading.
        self.G_optim._handles.clear()
        self.D_optim._handles.clear()
        """ Define Rho clipper to constraint the value of rho in AdaILN and ILN"""
        self.Rho_clipper = RhoClipper(0, 1)