Example #1
import os
import time
import pickle
import datetime

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter  # or tensorboardX, depending on the project
from tqdm import tqdm, trange

# Project-local helpers (`utils`, `test`, `optimize_loss`, `optimize_compression_loss`)
# are assumed to be in scope.

def train(args, model, train_loader, test_loader, device, logger, optimizers):

    start_time = time.time()
    test_loader_iter = iter(test_loader)
    current_D_steps, train_generator = 0, True
    best_loss, best_test_loss, mean_epoch_loss = np.inf, np.inf, np.inf
    train_writer = SummaryWriter(os.path.join(args.tensorboard_runs, 'train'))
    test_writer = SummaryWriter(os.path.join(args.tensorboard_runs, 'test'))
    storage, storage_test = model.storage_train, model.storage_test

    amortization_opt, hyperlatent_likelihood_opt = optimizers['amort'], optimizers['hyper']
    if model.use_discriminator:
        disc_opt = optimizers['disc']


    for epoch in trange(args.n_epochs, desc='Epoch'):

        epoch_loss, epoch_test_loss = [], []
        epoch_start_time = time.time()

        if epoch > 0:
            ckpt_path = utils.save_model(model, optimizers, mean_epoch_loss, epoch, device, args=args, logger=logger)

        model.train()

        for idx, (data, bpp) in enumerate(tqdm(train_loader, desc='Train'), 0):

            data = data.to(device, dtype=torch.float)

            try:
                if model.use_discriminator:
                    # Train D for D_steps, then G, using distinct batches
                    losses = model(data, train_generator=train_generator)
                    compression_loss = losses['compression']
                    disc_loss = losses['disc']

                    if train_generator:
                        optimize_compression_loss(compression_loss, amortization_opt, hyperlatent_likelihood_opt)
                        train_generator = False
                    else:
                        optimize_loss(disc_loss, disc_opt)
                        current_D_steps += 1

                        if current_D_steps == args.discriminator_steps:
                            current_D_steps = 0
                            train_generator = True

                        continue
                else:
                    # Rate, distortion, perceptual only
                    losses = model(data, train_generator=True)
                    compression_loss = losses['compression']
                    optimize_compression_loss(compression_loss, amortization_opt, hyperlatent_likelihood_opt)

            except KeyboardInterrupt:
                # Note: saving not guaranteed!
                if model.step_counter > args.log_interval+1:
                    logger.warning('Exiting, saving ...')
                    ckpt_path = utils.save_model(model, optimizers, mean_epoch_loss, epoch, device, args=args, logger=logger)
                    return model, ckpt_path
                else:
                    return model, None

            # Periodic logging plus a quick validation pass every `log_interval` steps
            if model.step_counter % args.log_interval == 1:
                epoch_loss.append(compression_loss.item())
                mean_epoch_loss = np.mean(epoch_loss)

                best_loss = utils.log(model, storage, epoch, idx, mean_epoch_loss, compression_loss.item(),
                                best_loss, start_time, epoch_start_time, batch_size=data.shape[0],
                                avg_bpp=bpp.mean().item(), logger=logger, writer=train_writer)
                try:
                    test_data, test_bpp = next(test_loader_iter)
                except StopIteration:
                    test_loader_iter = iter(test_loader)
                    test_data, test_bpp = next(test_loader_iter)

                best_test_loss, epoch_test_loss = test(args, model, epoch, idx, data, test_data, test_bpp,
                    device, epoch_test_loss, storage_test, best_test_loss, start_time, epoch_start_time,
                    logger, train_writer, test_writer)

                with open(os.path.join(args.storage_save, 'storage_{}_tmp.pkl'.format(args.name)), 'wb') as handle:
                    pickle.dump(storage, handle, protocol=pickle.HIGHEST_PROTOCOL)

                model.train()

                # LR scheduling
                utils.update_lr(args, amortization_opt, model.step_counter, logger)
                utils.update_lr(args, hyperlatent_likelihood_opt, model.step_counter, logger)
                if model.use_discriminator:
                    utils.update_lr(args, disc_opt, model.step_counter, logger)

                if model.step_counter > args.n_steps:
                    logger.info('Reached step limit [args.n_steps = {}]'.format(args.n_steps))
                    break

            if (idx % args.save_interval == 1) and (idx > args.save_interval):
                ckpt_path = utils.save_model(model, optimizers, mean_epoch_loss, epoch, device, args=args, logger=logger)

        # End epoch
        mean_epoch_loss = np.mean(epoch_loss)
        mean_epoch_test_loss = np.mean(epoch_test_loss)

        logger.info('===>> Epoch {} | Mean train loss: {:.3f} | Mean test loss: {:.3f}'.format(epoch,
            mean_epoch_loss, mean_epoch_test_loss))

        if model.step_counter > args.n_steps:
            break

    with open(os.path.join(args.storage_save, 'storage_{}_{:%Y_%m_%d_%H:%M:%S}.pkl'.format(args.name, datetime.datetime.now())), 'wb') as handle:
        pickle.dump(storage, handle, protocol=pickle.HIGHEST_PROTOCOL)

    ckpt_path = utils.save_model(model, optimizers, mean_epoch_loss, epoch, device, args=args, logger=logger)
    args.ckpt = ckpt_path
    logger.info("Training complete. Time elapsed: {:.3f} s. Number of steps: {}".format((time.time()-start_time), model.step_counter))

    return model, ckpt_path
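
Both examples call optimize_loss and optimize_compression_loss without defining them. The sketch below is a minimal plausible implementation assuming plain PyTorch optimizers; the real helpers may differ (gradient clipping, retain_graph handling, etc.).

def optimize_loss(loss, opt):
    # One optimizer step on a single loss term.
    opt.zero_grad()
    loss.backward()
    opt.step()

def optimize_compression_loss(compression_loss, amortization_opt, hyperlatent_likelihood_opt):
    # The amortized encoder/decoder and the hyperlatent likelihood model share one
    # rate-distortion objective, so backprop once and step both parameter groups.
    amortization_opt.zero_grad()
    hyperlatent_likelihood_opt.zero_grad()
    compression_loss.backward()
    amortization_opt.step()
    hyperlatent_likelihood_opt.step()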
Example #2 (variant of the loop above, with an auxiliary classifier head)
# Same imports as Example #1.
def train(args, model, train_loader, test_loader, device, logger, optimizers, bpp):

    start_time = time.time()
    test_loader_iter = iter(test_loader)
    current_D_steps, train_generator = 0, True
    best_loss, best_test_loss, mean_epoch_loss = np.inf, np.inf, np.inf
    train_writer = SummaryWriter(os.path.join(args.tensorboard_runs, 'train'))
    test_writer = SummaryWriter(os.path.join(args.tensorboard_runs, 'test'))
    storage, storage_test = model.storage_train, model.storage_test

    Ntrain = len(train_loader)
    # torch.zeros rather than torch.Tensor(Ntrain), which returns uninitialized memory
    classi_loss_total_train, classi_acc_total_train = torch.zeros(Ntrain), torch.zeros(Ntrain)

    classi_opt, amortization_opt, hyperlatent_likelihood_opt = optimizers['classi'], optimizers['amort'], optimizers['hyper']
    #end_of_epoch_metrics(args, model, train_loader, device, logger)
    if model.use_discriminator:
        disc_opt = optimizers['disc']

    best_mean_test_classi_loss_total = np.inf
    for epoch in trange(args.n_epochs, desc='Epoch'):

        epoch_loss, epoch_test_loss = [], []
        epoch_start_time = time.time()

        if epoch > 0:
            ckpt_path = utils.save_model(model, optimizers, mean_epoch_loss, epoch, device, args=args, logger=logger)

        model.train()
        test_index = 0
        test_acc_total = 0
        mean_test_acc_total = 0
        test_classi_loss_total = 0
        # Initialise so the end-of-epoch comparison is safe even if no eval step runs
        mean_test_classi_loss_total = np.inf

        for idx, (data, y) in enumerate(tqdm(train_loader, desc='Train'), 0):

            data = data.to(device, dtype=torch.float)
            y = y.to(device)
            try:
                if model.use_classiOnly:
                    losses = model(data, y, train_generator=False)
                    classi_loss = losses['classi']
                    classi_acc = losses['classi_acc']
                    compression_loss = losses['compression']

                    optimize_loss(classi_loss, classi_opt)
                    classi_loss_total_train[idx] = classi_loss.item()
                    classi_acc_total_train[idx] = classi_acc.item()
                    model.step_counter += 1
                elif model.use_discriminator:
                    # Train D for D_steps, then G, using distinct batches
                    losses = model(data, y, train_generator=train_generator)
                    compression_loss = losses['compression']
                    disc_loss = losses['disc']

                    if train_generator:
                        optimize_compression_loss(compression_loss, amortization_opt, hyperlatent_likelihood_opt)
                        train_generator = False
                    else:
                        optimize_loss(disc_loss, disc_opt)
                        current_D_steps += 1

                        if current_D_steps == args.discriminator_steps:
                            current_D_steps = 0
                            train_generator = True

                        continue
                else:
                    # Rate, distortion, perceptual only
                    losses = model(data, y, train_generator=True)
                    compression_loss = losses['compression']
                    optimize_compression_loss(compression_loss, amortization_opt, hyperlatent_likelihood_opt)

            except KeyboardInterrupt:
                # Note: saving not guaranteed!
                if model.step_counter > args.log_interval+1:
                    logger.warning('Exiting, saving ...')
                    ckpt_path = utils.save_model(model, optimizers, mean_epoch_loss, epoch, device, args=args, logger=logger)
                    return model, ckpt_path
                else:
                    return model, None

            if model.step_counter % args.log_interval == 0:
                epoch_loss.append(compression_loss.item())
                mean_epoch_loss = np.mean(epoch_loss)

                #best_loss = utils.log(model, storage, epoch, idx, mean_epoch_loss, compression_loss.item(),
                #                best_loss, start_time, epoch_start_time, batch_size=data.shape[0],
                #                avg_bpp=bpp, logger=logger, writer=train_writer)
                try:
                    test_data, ytest = next(test_loader_iter)
                except StopIteration:
                    test_loader_iter = iter(test_loader)
                    test_data, ytest = next(test_loader_iter)

                ytest = ytest.to(device)
                best_test_loss, epoch_test_loss, mean_test_acc, mean_test_classi_loss = test(
                    args, model, epoch, idx, data, y, test_data, ytest, device, epoch_test_loss,
                    storage_test, best_test_loss, start_time, epoch_start_time, logger,
                    train_writer, test_writer)

                # Running means of the test classification loss/accuracy over the epoch
                test_index += 1
                test_classi_loss_total += mean_test_classi_loss
                mean_test_classi_loss_total = test_classi_loss_total / test_index

                test_acc_total += mean_test_acc
                mean_test_acc_total = test_acc_total / test_index

                with open(os.path.join(args.storage_save, 'storage_{}_tmp.pkl'.format(args.name)), 'wb') as handle:
                    pickle.dump(storage, handle, protocol=pickle.HIGHEST_PROTOCOL)

                model.train()

                if model.step_counter > args.n_steps:
                    logger.info('Reached step limit [args.n_steps = {}]'.format(args.n_steps))
                    break


        # LR scheduling
        if model.use_classiOnly:
            utils.update_lr(args, classi_opt, model.step_counter, logger)
        utils.update_lr(args, amortization_opt, model.step_counter, logger)
        utils.update_lr(args, hyperlatent_likelihood_opt, model.step_counter, logger)
        if model.use_discriminator:
            utils.update_lr(args, disc_opt, model.step_counter, logger)

        # Checkpoint whenever the running mean of the test classification loss improves
        if mean_test_classi_loss_total < best_mean_test_classi_loss_total:
            logger.info(f'Classification loss decreased to {mean_test_classi_loss_total:.3f}. Saving model.')
            best_mean_test_classi_loss_total = mean_test_classi_loss_total
            ckpt_path = utils.save_model(model, optimizers, mean_epoch_loss, epoch, device, args=args, logger=logger)
        # End epoch
        mean_epoch_loss = np.mean(epoch_loss)
        mean_epoch_test_loss = np.mean(epoch_test_loss)

        logger.info('===>> Epoch {} | Mean train loss: {:.3f} | Mean test loss: {:.3f}  | Mean test classi acc: {:.3f}'.format(epoch,
            mean_epoch_loss, mean_epoch_test_loss, mean_test_acc_total))
        if model.use_classiOnly:
            logger.info(f'ClassiLossTrain: mean={classi_loss_total_train.mean():.3f}, std={classi_loss_total_train.std():.3f}')
            logger.info(f'ClassiAccTrain: mean={classi_acc_total_train.mean():.3f}, std={classi_acc_total_train.std():.3f}')

        #end_of_epoch_metrics(args, model, train_loader, device, logger)
        #end_of_epoch_metrics(args, model, test_loader, device, logger)


        if model.step_counter > args.n_steps:
            break

    with open(os.path.join(args.storage_save, 'storage_{}_{:%Y_%m_%d_%H:%M:%S}.pkl'.format(args.name, datetime.datetime.now())), 'wb') as handle:
        pickle.dump(storage, handle, protocol=pickle.HIGHEST_PROTOCOL)

    ckpt_path = utils.save_model(model, optimizers, mean_epoch_loss, epoch, device, args=args, logger=logger)
    args.ckpt = ckpt_path
    logger.info("Training complete. Time elapsed: {:.3f} s. Number of steps: {}".format((time.time()-start_time), model.step_counter))

    return model, ckpt_path
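
For reference, a hypothetical way to assemble the optimizers dict and invoke the Example #1 loop. Everything here other than train itself (build_model, the parameter-group accessors, learning_rate) is illustrative, not taken from the source.

import logging
import torch

# `args`, `train_loader`, and `test_loader` come from the project's CLI / data
# pipeline and are assumed to be in scope.
logger = logging.getLogger('train')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = build_model(args).to(device)  # hypothetical project-specific factory
optimizers = {
    'amort': torch.optim.Adam(model.amortization_parameters(), lr=args.learning_rate),           # assumed accessor
    'hyper': torch.optim.Adam(model.hyperlatent_likelihood_parameters(), lr=args.learning_rate),  # assumed accessor
}
if model.use_discriminator:
    optimizers['disc'] = torch.optim.Adam(model.discriminator_parameters(), lr=args.learning_rate)  # assumed accessor

model, ckpt_path = train(args, model, train_loader, test_loader, device, logger, optimizers)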