Example #1
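The imports below are reconstructed from the function body and are not part of the original listing; the project-local helpers (the registries, get_data_sizes, get_parameter_groups, get_lr_scheduler, CheckpointSaver) are left as commented placeholders because their module paths are unknown.

import json
import os
import random
from collections import defaultdict

import numpy as np
import torch
import yaml
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
from tqdm import tqdm

from apex import amp  # NVIDIA Apex; only used when args.amp is set
from torch.optim import AdamW  # may instead come from transformers, depending on the project

# Project-local helpers referenced below; module paths are assumptions, adjust to your layout.
# from registries import ModelRegistry, DatasetRegistry, EvaluatorRegistry
# from training_utils import CheckpointSaver, get_data_sizes, get_parameter_groups, get_lr_scheduler
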
def train(args, logger, tb_writer):
    logger.info('Args: {}'.format(json.dumps(vars(args), indent=4, sort_keys=True)))
    if args.local_rank in [-1, 0]:
        with open(os.path.join(args.save_dir, 'args.yaml'), 'w') as file:
            yaml.safe_dump(vars(args), file, sort_keys=False)

    device_id = args.local_rank if args.local_rank != -1 else 0
    device = torch.device('cuda', device_id)
    logger.warning(f'Using GPU {device_id}.')

    world_size = torch.distributed.get_world_size() if args.local_rank != -1 else 1
    logger.info(f'Total number of GPUs used: {world_size}.')
    effective_batch_size = args.batch_size * world_size * args.accumulation_steps
    logger.info(f'Effective batch size: {effective_batch_size}.')

    num_train_samples_per_epoch, num_dev_samples, num_unique_train_epochs = get_data_sizes(data_dir=args.data_dir,
                                                                                           num_epochs=args.num_epochs,
                                                                                           logger=logger)
    num_optimization_steps = sum(num_train_samples_per_epoch) // world_size // args.batch_size // \
                             args.accumulation_steps
    if args.max_steps > 0:
        num_optimization_steps = min(num_optimization_steps, args.max_steps)
    logger.info(f'Total number of optimization steps: {num_optimization_steps}.')

    # Set random seed
    logger.info(f'Using random seed {args.seed}.')
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # Get model
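    # In distributed runs, non-zero ranks wait at this barrier so that rank 0 downloads and
    # caches the pretrained weights before the other processes try to load them.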
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()

    logger.info(f'Loading model {args.model} for task {args.task}...')
    model = ModelRegistry.get_model(args.task).from_pretrained(args.model)

    if args.local_rank in [-1, 0]:
        with open(os.path.join(args.save_dir, 'config.json'), 'w') as file:
            json.dump(model.config.__dict__, file)

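    # Rank 0 reaches the barrier only now, releasing the other ranks once the model is ready.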
    if args.local_rank == 0:
        torch.distributed.barrier()

    model.to(device)

    # Get optimizer
    logger.info('Creating optimizer...')
    parameter_groups = get_parameter_groups(model)
    optimizer = AdamW(parameter_groups, lr=args.learning_rate, weight_decay=args.weight_decay, eps=1e-8)
    scheduler = get_lr_scheduler(optimizer, num_steps=num_optimization_steps, warmup_proportion=args.warmup_proportion)

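    # Optional mixed precision via NVIDIA Apex; torch.einsum is registered so it also runs in half precision.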
    if args.amp:
        amp.register_half_function(torch, 'einsum')
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.amp_opt_level)

    if args.local_rank != -1:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank,
                                        find_unused_parameters=True)

    # Get dev data loader
    dev_data_file = os.path.join(args.data_dir, 'dev.jsonl.gz')
    logger.info(f'Creating dev dataset from {dev_data_file}...')
    dev_dataset = DatasetRegistry.get_dataset(args.task)(data_file=dev_data_file,
                                                         data_size=num_dev_samples,
                                                         local_rank=-1)
    dev_loader = DataLoader(dev_dataset,
                            batch_size=2 * args.batch_size,
                            num_workers=1,
                            collate_fn=dev_dataset.collate_fn)

    # Get evaluator
    evaluator = EvaluatorRegistry.get_evaluator(args.task)(data_loader=dev_loader,
                                                           logger=logger,
                                                           tb_writer=tb_writer,
                                                           device=device,
                                                           world_size=world_size,
                                                           args=args)

    # Get saver
    saver = CheckpointSaver(save_dir=args.save_dir,
                            max_checkpoints=args.max_checkpoints,
                            primary_metric=evaluator.primary_metric,
                            maximize_metric=evaluator.maximize_metric,
                            logger=logger)

    global_step = 0
    samples_processed = 0

    # Train
    logger.info('Training...')
    samples_till_eval = args.eval_every
    for epoch in range(1, args.num_epochs + 1):
        # Get train data loader for current epoch
        train_data_file_num = ((epoch - 1) % num_unique_train_epochs) + 1
        train_data_file = os.path.join(args.data_dir, f'epoch_{train_data_file_num}.jsonl.gz')
        logger.info(f'Creating training dataset from {train_data_file}...')
        train_dataset = DatasetRegistry.get_dataset(args.task)(train_data_file,
                                                               data_size=num_train_samples_per_epoch[epoch - 1],
                                                               local_rank=args.local_rank,
                                                               world_size=world_size)
        train_loader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  num_workers=1,
                                  collate_fn=train_dataset.collate_fn)

        logger.info(f'Starting epoch {epoch}...')
        model.train()
        model.zero_grad()
        loss_values = defaultdict(float)
        samples_till_end = (num_optimization_steps - global_step) * effective_batch_size
        samples_in_cur_epoch = min([len(train_loader.dataset), samples_till_end])
        disable_progress_bar = (args.local_rank not in [-1, 0])
        with tqdm(total=samples_in_cur_epoch, disable=disable_progress_bar) as progress_bar:
            for step, batch in enumerate(train_loader, 1):
                batch = {name: tensor.to(device) for name, tensor in batch.items()}
                current_batch_size = batch['input_ids'].shape[0]

                outputs = model(**batch)
                loss, current_loss_values = outputs[:2]

                loss = loss / args.accumulation_steps
                for name, value in current_loss_values.items():
                    loss_values[name] += value / args.accumulation_steps

                if args.amp:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                samples_processed += current_batch_size * world_size
                samples_till_eval -= current_batch_size * world_size
                progress_bar.update(current_batch_size * world_size)

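                # Gradient accumulation: the optimizer and scheduler step, and the gradients are
                # zeroed, only once every accumulation_steps micro-batches.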
                if step % args.accumulation_steps == 0:
                    current_lr = scheduler.get_last_lr()[0]

                    if args.amp:
                        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1.0)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

                    optimizer.step()
                    scheduler.step()
                    model.zero_grad()
                    global_step += 1

                    # Log info
                    progress_bar.set_postfix(epoch=epoch, step=global_step, lr=current_lr, **loss_values)
                    if args.local_rank in [-1, 0]:
                        tb_writer.add_scalar('train/LR', current_lr, global_step)
                        for name, value in loss_values.items():
                            tb_writer.add_scalar(f'train/{name}', value, global_step)
                    loss_values = {name: 0 for name in loss_values}

                    if global_step == args.max_steps:
                        logger.info('Reached maximum number of optimization steps.')
                        break

                    if samples_till_eval <= 0:
                        samples_till_eval = args.eval_every
                        eval_results = evaluator.evaluate(model, global_step)
                        if args.local_rank in [-1, 0]:
                            saver.save(model, global_step, eval_results)

            if not args.do_not_eval_after_epoch:
                eval_results = evaluator.evaluate(model, global_step)
                if args.local_rank in [-1, 0]:
                    saver.save(model, global_step, eval_results)

        # Stop the epoch loop as well once the maximum number of steps has been reached;
        # the break above only exits the batch loop.
        if 0 < args.max_steps <= global_step:
            break
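
For context, a minimal sketch of how train might be driven follows. The argument names are taken from the attributes the function reads; the defaults, the logging and SummaryWriter setup, and the launch details are assumptions rather than part of the original script.

if __name__ == '__main__':
    import argparse
    import logging

    from torch.utils.tensorboard import SummaryWriter

    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=-1)  # set by the distributed launcher
    parser.add_argument('--save_dir', type=str, default='runs/example')
    parser.add_argument('--data_dir', type=str, required=True)
    parser.add_argument('--task', type=str, required=True)
    parser.add_argument('--model', type=str, required=True)
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--accumulation_steps', type=int, default=1)
    parser.add_argument('--num_epochs', type=int, default=3)
    parser.add_argument('--max_steps', type=int, default=0)
    parser.add_argument('--learning_rate', type=float, default=3e-5)
    parser.add_argument('--weight_decay', type=float, default=0.01)
    parser.add_argument('--warmup_proportion', type=float, default=0.1)
    parser.add_argument('--eval_every', type=int, default=50000)
    parser.add_argument('--max_checkpoints', type=int, default=3)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--amp', action='store_true')
    parser.add_argument('--amp_opt_level', type=str, default='O1')
    parser.add_argument('--do_not_eval_after_epoch', action='store_true')
    args = parser.parse_args()

    os.makedirs(args.save_dir, exist_ok=True)
    if args.local_rank != -1:
        # Each process is started with its own --local_rank (e.g. via torch.distributed.launch).
        torch.distributed.init_process_group(backend='nccl')

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)
    # Writing TensorBoard events only from rank 0 (or single-GPU) is an assumption here.
    tb_writer = SummaryWriter(args.save_dir) if args.local_rank in [-1, 0] else None

    train(args, logger, tb_writer)
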
Example #2
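This snippet is the generator half of a GAN training loop: imgs, valid, d_loss, batches_done, the models, the optimizers, and the two savers are created earlier in the surrounding script (not shown), and save_image presumably comes from torchvision.utils.
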
        optimizer_G.zero_grad()

        # Sample noise as generator input
        z = torch.randn(imgs.shape[0], 4 * opt.hidden, device=device)

        # Generate a batch of images
        gen_imgs = generator(z)

        # Loss measures generator's ability to fool the discriminator
        # Train on fake images
        gen_validity = discriminator(gen_imgs)
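        # valid (built earlier in the loop) is used as the upstream gradient for the critic scores,
        # so the generator is updated directly through the critic output (a WGAN-style update).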
        gen_validity.backward(valid)

        optimizer_G.step()

        print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" %
              (epoch, opt.n_epochs, opt.n_critic *
               (i + 1), len(dataloader), d_loss.data[0], gen_validity.data[0]))

        batches_done += opt.n_critic
    if epoch % opt.sample_interval == 0:
        save_image(gen_imgs.data[:25],
                   f'{opt.save_dir_name}/{batches_done}.png',
                   nrow=5,
                   normalize=True)
    if epoch % opt.save_interval == 0:
        discriminator_saver.save(discriminator, batches_done, optimizer_D,
                                 epoch)
        generator_saver.save(generator, batches_done, optimizer_G, epoch)
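
The critic update that produces d_loss and valid is not shown in this snippet. A minimal sketch of what it might look like in a WGAN-style setup is given below; the loss sign convention, the clip_value option, the reuse of the same real batch, and the definition of valid are all assumptions made for illustration, not part of the original script.

# Hypothetical critic update, run opt.n_critic times before each generator step above.
for _ in range(opt.n_critic):
    optimizer_D.zero_grad()

    real_imgs = imgs.to(device)
    z = torch.randn(imgs.shape[0], 4 * opt.hidden, device=device)
    fake_imgs = generator(z).detach()  # do not backpropagate into the generator here

    # Common WGAN objective: the critic maximizes D(real) - D(fake).
    d_loss = -(discriminator(real_imgs).mean() - discriminator(fake_imgs).mean())
    d_loss.backward()
    optimizer_D.step()

    # Weight clipping keeps the critic approximately Lipschitz (original WGAN recipe).
    for p in discriminator.parameters():
        p.data.clamp_(-opt.clip_value, opt.clip_value)

# With this sign convention, a valid tensor of -1s used as the upstream gradient in
# gen_validity.backward(valid) makes the generator step increase D(G(z)).
valid = -torch.ones(imgs.shape[0], 1, device=device)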