Example #1
def train(model, input_data, output_data, learning_rate=1e-3, epochs=1000):
  """Fits model to training data and returns train loss."""

  optimizer = create_optimizer(model, learning_rate=learning_rate, weight_decay=0)

  loss_fn_kwargs = {}

  # Run full-batch training for the requested number of epochs.
  for epoch in range(epochs):
    optimizer = train_step(optimizer, input_data, output_data, mse_loss, loss_fn_kwargs)

  preds = jnp.squeeze(optimizer.target(input_data), axis=1)

  train_loss = jnp.mean(jnp.square(preds - output_data))

  return train_loss
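The final loss computation in this example squeezes the model output before subtracting the targets. The following is a minimal, self-contained sketch (toy data invented here, assuming only that JAX is installed) of why that squeeze matters: with predictions of shape (N, 1) and targets of shape (N,), the unsqueezed difference silently broadcasts to an (N, N) matrix and the mean is taken over all N*N pairwise errors.

import jax.numpy as jnp

preds = jnp.array([[0.9], [2.1], [2.8]])  # model output, shape (3, 1)
targets = jnp.array([1.0, 2.0, 3.0])      # output_data, shape (3,)

# Without the squeeze, (3, 1) - (3,) broadcasts to a (3, 3) matrix of pairwise errors.
wrong = jnp.mean(jnp.square(preds - targets))
# With the squeeze, shapes match and this is the intended per-example MSE.
right = jnp.mean(jnp.square(jnp.squeeze(preds, axis=1) - targets))

print(float(wrong), float(right))  # ~1.29 vs 0.02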
Example #2
def train(model,
          optimizer,
          lr_scheduler,
          train_data_iterator,
          val_data_iterator,
          timers,
          args,
          summary_writer=None):
    """Train the model."""

    # Turn on training mode which enables dropout.
    model.train()

    # Tracking loss.
    total_lm_loss = 0.0

    # Iterations.
    skipped_iters = 0

    timers('interval time').start()
    report_memory_flag = True
    mems = []
    while args.iteration < args.train_iters:

        lm_loss, skipped_iter, mems = train_step(
            train_data_iterator,
            model,
            optimizer,
            lr_scheduler,
            args,
            timers,
            mems=mems,
            forward_step_func=forward_step)
        skipped_iters += skipped_iter
        args.iteration += 1

        # Update losses.
        total_lm_loss += lm_loss.data.detach().float()

        # Logging.
        if args.iteration % args.log_interval == 0:
            learning_rate = optimizer.param_groups[0]['lr']
            avg_lm_loss = total_lm_loss.item() / args.log_interval
            elapsed_time = timers('interval time').elapsed()
            report_iteration_metrics(summary_writer, optimizer, learning_rate,
                                     avg_lm_loss,
                                     elapsed_time * 1000.0 / args.log_interval,
                                     args.iteration, args.train_iters, args)
            total_lm_loss = 0.0
            if report_memory_flag:
                report_memory('after {} iterations'.format(args.iteration))
                report_memory_flag = False
            # for i in range(torch.distributed.get_world_size()):
            #     if i == torch.distributed.get_rank():
            #         print(get_hostname())
            #         timers.log(['forward', 'backward', 'optimizer',
            #                     'batch generator', 'data loader'],
            #                    normalizer=args.log_interval, reset=False)
            #     torch.distributed.barrier()
            if args.deepspeed or args.DDP_impl == 'torch':
                timers.log([
                    'forward', 'backward', 'optimizer', 'batch generator',
                    'data loader'
                ],
                           normalizer=args.log_interval)
            else:
                timers.log([
                    'forward', 'backward', 'allreduce', 'optimizer',
                    'batch generator', 'data loader'
                ],
                           normalizer=args.log_interval)
        # Checkpointing
        if args.save and args.save_interval and args.iteration % args.save_interval == 0:
            save_checkpoint(args.iteration, model, optimizer, lr_scheduler,
                            args)

        # Evaluation
        if args.eval_interval and args.iteration % args.eval_interval == 0 and args.do_valid:
            prefix = 'iteration {}'.format(args.iteration)
            evaluate_and_print_results(prefix,
                                       val_data_iterator,
                                       model,
                                       args,
                                       timers,
                                       verbose=False,
                                       step=args.iteration,
                                       summary_writer=summary_writer,
                                       forward_step_func=forward_step)

    return args.iteration, skipped_iters
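The logging branch in this example accumulates total_lm_loss on every step, divides by log_interval when it reports, and then resets the accumulator, so each reported value averages only the most recent window rather than the whole run. Below is a small sketch of that pattern in isolation (plain Python, loss values invented here):

def report_windowed_average(losses, log_interval=4):
    total = 0.0
    for step, loss in enumerate(losses, start=1):
        total += loss
        if step % log_interval == 0:
            avg = total / log_interval
            print('step {}: avg loss over last {} steps = {:.4f}'.format(
                step, log_interval, avg))
            total = 0.0  # reset so the next window starts fresh

report_windowed_average([4.0, 3.5, 3.0, 2.5, 2.0, 1.5, 1.0, 0.5])
# step 4: avg loss over last 4 steps = 3.2500
# step 8: avg loss over last 4 steps = 1.2500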
Example #3
def main():
    #  load datasets
    train_loader, test_loader = load_dataset(args.train_set, args.batch_size)

    # initialize model
    if args.model_file:
        try:
            total_examples, fixed_noise, gen_losses, disc_losses, gen_loss_per_epoch, disc_loss_per_epoch, \
            prev_epoch, gan, disc_optimizer, gen_optimizer, memory, use_EM, use_mcgn \
                = load_model(args.model_file, args.cuda, args.learning_rate, args.beta_0, args.beta_1)
            print('model loaded successfully! resuming from epoch {}'.format(
                prev_epoch))
            args.memory = memory  # prevents any contradictions during loading
            args.use_EM = use_EM
            args.use_mcgn = use_mcgn
        except Exception:
            print('could not load model! creating new model...')
            args.model_file = None

    if not args.model_file:
        print('creating new model...')
        total_examples, fixed_noise, gen_losses, disc_losses, gen_loss_per_epoch, disc_loss_per_epoch, \
        prev_epoch, gan, disc_optimizer, gen_optimizer \
            = create_new_model(args.train_set, args.cuda, args.learning_rate, args.beta_0,
                               args.beta_1, args.memory, use_mcgn=args.use_mcgn)

    # results save folder
    gen_images_dir = 'results/generated_images'
    train_summaries_dir = 'results/training_summaries'
    checkpoint_dir = 'results/checkpoints'
    if not os.path.isdir('results'):
        os.mkdir('results')
    if not os.path.isdir(gen_images_dir):
        os.mkdir(gen_images_dir)
    if not os.path.isdir(train_summaries_dir):
        os.mkdir(train_summaries_dir)
    if not os.path.isdir(checkpoint_dir):
        os.mkdir(checkpoint_dir)

    # Reset training seed to ensure that batches remain the same between runs!
    np.random.seed(args.seed)

    try:
        for epoch in range(prev_epoch, args.n_epochs):

            disc_losses_epoch = []
            gen_losses_epoch = []
            for idx, (true_batch, _) in enumerate(train_loader):
                #time.sleep(1)
                disc_train_loss, gen_train_loss, disc_true_accuracy, disc_fake_accuracy \
                        = train_step(gan=gan, batch_size=args.batch_size, is_cuda=args.cuda, true_batch=true_batch,
                                     grad_clip=args.grad_clip, disc_optimizer=disc_optimizer,
                                     gen_optimizer=gen_optimizer, memory=args.memory, use_EM=args.use_EM)

                if (total_examples != 0) and (total_examples %
                                              args.display_result_every == 0):
                    print(
                        'epoch {}: step {}/{} disc true acc: {:.4f} disc fake acc: {:.4f} '
                        'disc loss: {:.4f}, gen loss: {:.4f}'.format(
                            epoch + 1, idx + 1, len(train_loader),
                            disc_true_accuracy, disc_fake_accuracy,
                            disc_train_loss.item(), gen_train_loss.item()))

                if args.memory and args.verbose:
                    next_h, next_k, next_a, next_v = gan.memory.get_info_for_logging(
                    )

                    if (total_examples != 0) and (
                            total_examples % args.display_result_every == 0):
                        print(
                            'avg hist: {:.4f} avg age: {:.5f} avg key val: {:.4f}'
                            .format(next_h.mean(), next_a.mean(),
                                    next_v.mean()))
                        print(
                            'min: {:.3f}, max: {:.3f}, median: {:.3f} mean: {:.3f}'
                            .format(next_h.min().item(),
                                    next_h.max().item(),
                                    next_h.median().item(),
                                    next_h.mean().item()))

                total_examples += args.batch_size

                # Checkpoint model
                if (total_examples != 0) and (total_examples %
                                              args.checkpoint_interval == 0):

                    disc_losses.extend(disc_losses_epoch)
                    gen_losses.extend(gen_losses_epoch)
                    save_all(total_examples=total_examples,
                             fixed_noise=fixed_noise,
                             gan=gan,
                             disc_loss_per_epoch=disc_loss_per_epoch,
                             gen_loss_per_epoch=gen_loss_per_epoch,
                             gen_losses=gen_losses,
                             disc_losses=disc_losses,
                             epoch=epoch,
                             checkpoint_dir=checkpoint_dir,
                             is_cuda=args.cuda,
                             gen_images_dir=gen_images_dir,
                             train_summaries_dir=train_summaries_dir,
                             disc_optimizer=disc_optimizer,
                             gen_optimizer=gen_optimizer,
                             train_set=args.train_set,
                             memory=args.memory,
                             use_EM=args.use_EM,
                             use_mcgn=args.use_mcgn)

                #  Collect information per epoch
                disc_losses_epoch.append(disc_train_loss.item())
                gen_losses_epoch.append(gen_train_loss.item())

            disc_loss_per_epoch.append(np.average(disc_losses_epoch))
            gen_loss_per_epoch.append(np.average(gen_losses_epoch))

            # Save epoch learning curve
            save_learning_curve_epoch(gen_losses=gen_loss_per_epoch,
                                      disc_losses=disc_loss_per_epoch,
                                      total_epochs=epoch + 1,
                                      directory=train_summaries_dir)
            print("Saved learning curves!")

            print('epoch {}/{} disc loss: {:.4f}, gen loss: {:.4f}'.format(
                epoch + 1, args.n_epochs,
                np.array(disc_losses_epoch).mean(),
                np.array(gen_losses_epoch).mean()))

            disc_losses.extend(disc_losses_epoch)
            gen_losses.extend(gen_losses_epoch)

    except KeyboardInterrupt:
        print("Saving before quit...")
        save_all(total_examples=total_examples,
                 fixed_noise=fixed_noise,
                 gan=gan,
                 disc_loss_per_epoch=disc_loss_per_epoch,
                 gen_loss_per_epoch=gen_loss_per_epoch,
                 gen_losses=gen_losses,
                 disc_losses=disc_losses,
                 epoch=epoch,
                 checkpoint_dir=checkpoint_dir,
                 is_cuda=args.cuda,
                 gen_images_dir=gen_images_dir,
                 train_summaries_dir=train_summaries_dir,
                 disc_optimizer=disc_optimizer,
                 gen_optimizer=gen_optimizer,
                 train_set=args.train_set,
                 memory=args.memory,
                 use_EM=args.use_EM,
                 use_mcgn=args.use_mcgn)
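A side note on the results-folder setup near the top of this example: the chain of os.path.isdir/os.mkdir checks can be collapsed with os.makedirs and exist_ok=True, which also creates missing parent directories. A minimal equivalent sketch:

import os

for directory in ('results/generated_images',
                  'results/training_summaries',
                  'results/checkpoints'):
    # Creates 'results' implicitly and is a no-op if the directory already exists.
    os.makedirs(directory, exist_ok=True)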
Example #4
def _train(model,
           optimizer,
           lr_scheduler,
           forward_step,
           train_dataloader,
           valid_dataloader,
           end_of_epoch_callback,
           args,
           timers,
           summary_writer=None):
    """Train the model."""

    # Turn on training mode which enables dropout.
    model.train()

    # Tracking loss.
    args.iteration = 0
    total_lm_loss = 0.0
    best_score, best_iteration = 0, None
    # Starting epoch and iteration
    start_epoch = args.iteration // args.train_iters_per_epoch
    start_iteration = args.iteration % args.train_iters_per_epoch
    if not args.block_lm_ratio:
        valid_dataloader = valid_dataloader[0]
    # For each remaining epoch
    timers('interval time').start()
    for epoch in range(start_epoch, args.epochs):
        print_rank_0('working on epoch {} ...'.format(epoch))

        # Set the data loader epoch to shuffle the index iterator.
        if mpu.get_model_parallel_rank() == 0:
            train_dataloader[0].sampler.set_epoch(args.seed + epoch)

        # For all the batches in the dataset.
        for iteration_, batch in enumerate(train_dataloader[0]):

            # Ignore the iterations before starting value
            if iteration_ < start_iteration:
                continue
            # Set to zero so the next epoch does not skip any batches.
            start_iteration = 0

            # Train for one step.
            if args.block_lm_ratio > 0.0:
                data = (batch, train_dataloader[1])
            else:
                data = batch
            lm_loss, skipped_iter, _ = train_step(
                data,
                model,
                optimizer,
                lr_scheduler,
                args,
                timers,
                forward_step_func=forward_step,
                single_step=True)
            args.iteration += 1
            total_lm_loss += lm_loss.data.detach().float()

            # Logging.
            if args.iteration % args.log_interval == 0:
                learning_rate = optimizer.param_groups[0]['lr']
                avg_lm_loss = total_lm_loss.item() / args.log_interval
                elapsed_time = timers('interval time').elapsed()
                timers.log([
                    'forward', 'backward', 'allreduce', 'optimizer',
                    'batch generator'
                ],
                           normalizer=args.log_interval)
                report_iteration_metrics(
                    summary_writer, optimizer, learning_rate, avg_lm_loss,
                    elapsed_time * 1000.0 / args.log_interval, args.iteration,
                    args.train_iters, args)
                total_lm_loss = 0.0

            # Evaluation
            if args.eval_interval and valid_dataloader is not None and args.iteration % args.eval_interval == 0:
                prefix = 'iteration {}'.format(args.iteration)
                evaluate_and_print_results(prefix,
                                           valid_dataloader,
                                           model,
                                           args,
                                           timers,
                                           step=args.iteration,
                                           verbose=False,
                                           forward_step_func=forward_step,
                                           summary_writer=summary_writer)

        # Checkpointing at the end of each epoch.
        if args.save and (epoch + 1) % args.save_epoch == 0:
            save_checkpoint(args.iteration,
                            model,
                            optimizer,
                            lr_scheduler,
                            args,
                            only_changed_parameters=True)

        # Callback at the end of each epoch.
        if end_of_epoch_callback is not None and (epoch +
                                                  1) % args.eval_epoch == 0:
            score_dict = end_of_epoch_callback(model,
                                               epoch,
                                               summary_writer=summary_writer)
            validation_metric = args.validation_metric if args.validation_metric else list(
                score_dict.keys())[0]
            validation_score = score_dict[validation_metric]
            if best_iteration is None or validation_score > best_score:
                best_iteration = args.iteration
                best_score = validation_score
                print_rank_0(
                    f"Found best {validation_metric} {best_score} at {best_iteration}"
                )
                save_checkpoint(args.iteration,
                                model,
                                optimizer,
                                lr_scheduler,
                                args,
                                tag="best",
                                barrier=False,
                                only_changed_parameters=True,
                                no_deepspeed=True,
                                no_save_optim=True)
                if torch.distributed.get_rank() == 0:
                    score_dict.update({"type": "validation", "epoch": epoch})
                    with open(os.path.join(args.log_dir, "results.json"),
                              "w") as output:
                        output.write(json.dumps(score_dict) + "\n")
                    with open(
                            os.path.join(args.save,
                                         "best_checkpointed_iteration.txt"),
                            "w") as output:
                        output.write(str(best_iteration))
    torch.distributed.barrier()
    return best_iteration
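The end-of-epoch callback in this example implements a keep-the-best policy: the validation score is compared against the best seen so far, and a "best" checkpoint is written only when the score improves. A stripped-down sketch of that selection logic (scores and iteration numbers invented here):

best_score, best_iteration = 0, None

def is_new_best(iteration, validation_score):
    global best_score, best_iteration
    if best_iteration is None or validation_score > best_score:
        best_score, best_iteration = validation_score, iteration
        return True  # the caller would save the "best" checkpoint here
    return False

for it, score in [(100, 0.62), (200, 0.71), (300, 0.68)]:
    if is_new_best(it, score):
        print('new best score {} at iteration {}'.format(score, it))
# new best score 0.62 at iteration 100
# new best score 0.71 at iteration 200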