import json
import os
import random
from collections import defaultdict

import numpy as np
import torch
import yaml
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
from tqdm import tqdm

# The two imports below are assumptions about where these names come from: AdamW may
# equally be torch.optim.AdamW, and `amp` is NVIDIA apex (register_half_function and
# opt_level are apex-specific). Project-internal helpers (get_data_sizes,
# get_parameter_groups, get_lr_scheduler, ModelRegistry, DatasetRegistry,
# EvaluatorRegistry, CheckpointSaver) are assumed to be imported from the surrounding package.
from apex import amp
from transformers import AdamW


def train(args, logger, tb_writer):
    logger.info('Args: {}'.format(json.dumps(vars(args), indent=4, sort_keys=True)))
    if args.local_rank in [-1, 0]:
        with open(os.path.join(args.save_dir, 'args.yaml'), 'w') as file:
            yaml.safe_dump(vars(args), file, sort_keys=False)

    device_id = args.local_rank if args.local_rank != -1 else 0
    device = torch.device('cuda', device_id)
    logger.warning(f'Using GPU {device_id}.')

    world_size = torch.distributed.get_world_size() if args.local_rank != -1 else 1
    logger.info(f'Total number of GPUs used: {world_size}.')
    effective_batch_size = args.batch_size * world_size * args.accumulation_steps
    logger.info(f'Effective batch size: {effective_batch_size}.')

    num_train_samples_per_epoch, num_dev_samples, num_unique_train_epochs = get_data_sizes(
        data_dir=args.data_dir, num_epochs=args.num_epochs, logger=logger)
    num_optimization_steps = sum(num_train_samples_per_epoch) // world_size // args.batch_size // \
        args.accumulation_steps
    if args.max_steps > 0:
        num_optimization_steps = min(num_optimization_steps, args.max_steps)
    logger.info(f'Total number of optimization steps: {num_optimization_steps}.')

    # Set random seed
    logger.info(f'Using random seed {args.seed}.')
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # Get model
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()
    logger.info(f'Loading model {args.model} for task {args.task}...')
    model = ModelRegistry.get_model(args.task).from_pretrained(args.model)
    if args.local_rank in [-1, 0]:
        with open(os.path.join(args.save_dir, 'config.json'), 'w') as file:
            json.dump(model.config.__dict__, file)
    if args.local_rank == 0:
        torch.distributed.barrier()
    model.to(device)

    # Get optimizer
    logger.info('Creating optimizer...')
    parameter_groups = get_parameter_groups(model)
    optimizer = AdamW(parameter_groups, lr=args.learning_rate,
                      weight_decay=args.weight_decay, eps=1e-8)
    scheduler = get_lr_scheduler(optimizer, num_steps=num_optimization_steps,
                                 warmup_proportion=args.warmup_proportion)

    if args.amp:
        amp.register_half_function(torch, 'einsum')
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.amp_opt_level)

    if args.local_rank != -1:
        model = DistributedDataParallel(model, device_ids=[args.local_rank],
                                        output_device=args.local_rank,
                                        find_unused_parameters=True)

    # Get dev data loader
    dev_data_file = os.path.join(args.data_dir, 'dev.jsonl.gz')
    logger.info(f'Creating dev dataset from {dev_data_file}...')
    dev_dataset = DatasetRegistry.get_dataset(args.task)(data_file=dev_data_file,
                                                         data_size=num_dev_samples,
                                                         local_rank=-1)
    dev_loader = DataLoader(dev_dataset, batch_size=2 * args.batch_size,
                            num_workers=1, collate_fn=dev_dataset.collate_fn)

    # Get evaluator
    evaluator = EvaluatorRegistry.get_evaluator(args.task)(data_loader=dev_loader,
                                                           logger=logger, tb_writer=tb_writer,
                                                           device=device, world_size=world_size,
                                                           args=args)

    # Get saver
    saver = CheckpointSaver(save_dir=args.save_dir, max_checkpoints=args.max_checkpoints,
                            primary_metric=evaluator.primary_metric,
                            maximize_metric=evaluator.maximize_metric, logger=logger)

    global_step = 0
    samples_processed = 0

    # Train
    logger.info('Training...')
    samples_till_eval = args.eval_every
    for epoch in range(1, args.num_epochs + 1):
        # Get train data loader for current epoch
        train_data_file_num = ((epoch - 1) % num_unique_train_epochs) + 1
        train_data_file = os.path.join(args.data_dir, f'epoch_{train_data_file_num}.jsonl.gz')
        logger.info(f'Creating training dataset from {train_data_file}...')
        train_dataset = DatasetRegistry.get_dataset(args.task)(train_data_file,
                                                               data_size=num_train_samples_per_epoch[epoch - 1],
                                                               local_rank=args.local_rank,
                                                               world_size=world_size)
        train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                                  num_workers=1, collate_fn=train_dataset.collate_fn)

        logger.info(f'Starting epoch {epoch}...')
        model.train()
        model.zero_grad()
        loss_values = defaultdict(float)
        samples_till_end = (num_optimization_steps - global_step) * effective_batch_size
        samples_in_cur_epoch = min([len(train_loader.dataset), samples_till_end])
        disable_progress_bar = (args.local_rank not in [-1, 0])
        with tqdm(total=samples_in_cur_epoch, disable=disable_progress_bar) as progress_bar:
            for step, batch in enumerate(train_loader, 1):
                batch = {name: tensor.to(device) for name, tensor in batch.items()}
                current_batch_size = batch['input_ids'].shape[0]

                outputs = model(**batch)
                loss, current_loss_values = outputs[:2]

                loss = loss / args.accumulation_steps
                for name, value in current_loss_values.items():
                    loss_values[name] += value / args.accumulation_steps

                if args.amp:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                samples_processed += current_batch_size * world_size
                samples_till_eval -= current_batch_size * world_size
                progress_bar.update(current_batch_size * world_size)

                if step % args.accumulation_steps == 0:
                    current_lr = scheduler.get_last_lr()[0]

                    if args.amp:
                        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1.0)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

                    optimizer.step()
                    scheduler.step()
                    model.zero_grad()
                    global_step += 1

                    # Log info
                    progress_bar.set_postfix(epoch=epoch, step=global_step,
                                             lr=current_lr, **loss_values)
                    if args.local_rank in [-1, 0]:
                        tb_writer.add_scalar('train/LR', current_lr, global_step)
                        for name, value in loss_values.items():
                            tb_writer.add_scalar(f'train/{name}', value, global_step)
                    loss_values = {name: 0 for name in loss_values}

                    if global_step == args.max_steps:
                        logger.info('Reached maximum number of optimization steps.')
                        break

                    if samples_till_eval <= 0:
                        samples_till_eval = args.eval_every
                        eval_results = evaluator.evaluate(model, global_step)
                        if args.local_rank in [-1, 0]:
                            saver.save(model, global_step, eval_results)

        if not args.do_not_eval_after_epoch:
            eval_results = evaluator.evaluate(model, global_step)
            if args.local_rank in [-1, 0]:
                saver.save(model, global_step, eval_results)
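# A minimal, illustrative launch sketch for train() above, not the project's actual
# entry point: the argument names and defaults below are assumptions that simply mirror
# the attributes read from `args` in train(), and the distributed setup assumes the
# processes were started via torch.distributed.launch (which supplies --local_rank).
def _example_main():
    import argparse
    import logging

    from torch.utils.tensorboard import SummaryWriter

    parser = argparse.ArgumentParser()
    parser.add_argument('--model', required=True)
    parser.add_argument('--task', required=True)
    parser.add_argument('--data_dir', required=True)
    parser.add_argument('--save_dir', required=True)
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--accumulation_steps', type=int, default=1)
    parser.add_argument('--num_epochs', type=int, default=3)
    parser.add_argument('--max_steps', type=int, default=-1)
    parser.add_argument('--learning_rate', type=float, default=3e-5)
    parser.add_argument('--weight_decay', type=float, default=0.01)
    parser.add_argument('--warmup_proportion', type=float, default=0.1)
    parser.add_argument('--eval_every', type=int, default=50000)
    parser.add_argument('--max_checkpoints', type=int, default=5)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--amp', action='store_true')
    parser.add_argument('--amp_opt_level', default='O1')
    parser.add_argument('--do_not_eval_after_epoch', action='store_true')
    parser.add_argument('--local_rank', type=int, default=-1)
    args = parser.parse_args()

    if args.local_rank != -1:
        # One process per GPU: bind this process to its device and join the NCCL group.
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl')

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)
    tb_writer = SummaryWriter(args.save_dir)

    train(args, logger, tb_writer)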
# Generator update step; assumes generator, discriminator, optimizer_G, optimizer_D,
# valid, d_loss, dataloader, imgs, device, opt, batches_done, and the checkpoint savers
# are defined by the surrounding training loop.
optimizer_G.zero_grad()

# Sample noise as generator input
z = torch.randn(imgs.shape[0], 4 * opt.hidden, device=device)

# Generate a batch of images
gen_imgs = generator(z)

# Loss measures generator's ability to fool the discriminator
# Train on fake images
gen_validity = discriminator(gen_imgs)
gen_validity.backward(valid)
optimizer_G.step()

print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
      % (epoch, opt.n_epochs, opt.n_critic * (i + 1), len(dataloader),
         d_loss.item(), gen_validity.item()))  # .item() replaces the deprecated .data[0]

batches_done += opt.n_critic

if epoch % opt.sample_interval == 0:
    save_image(gen_imgs.data[:25], f'{opt.save_dir_name}/{batches_done}.png',
               nrow=5, normalize=True)

if epoch % opt.save_interval == 0:
    discriminator_saver.save(discriminator, batches_done, optimizer_D, epoch)
    generator_saver.save(generator, batches_done, optimizer_G, epoch)
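# The gen_validity.backward(valid) call above relies on backward() taking an explicit
# gradient argument: for a one-element critic output, passing a tensor of ones (or
# minus ones) as `valid` sets the sign of the gradient that is backpropagated, the
# old WGAN-style way of pushing the critic score up or down. A tiny standalone sketch
# of that pattern follows; the names here are purely illustrative and not part of the
# training script.
def _backward_with_gradient_demo():
    score = torch.tensor([2.5], requires_grad=True)
    grad_sign = torch.ones(1)    # plays the role of `valid`
    score.backward(grad_sign)    # accumulates d(score)/d(score) * grad_sign
    return score.grad            # tensor([1.])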