def train_eval(self):
        """Train for iterations ``self.st_iter..self.ed_iter`` with periodic evaluation.

        Every iteration puts ``self.model`` in train mode and calls
        ``self.train_iter(it)``.  Every ``self.iter_save`` iterations it:

        * switches to eval mode and runs ``self.eval(it)``;
        * logs train and eval averages side by side to ``self.logger``;
        * resets ``self.train_meter``;
        * records the best result so far (lowest ``absrel``) in ``self.best_txt``;
        * saves a checkpoint.

        After every iteration the learning rate is updated with ``do_schedule`` and
        the learning rate of every optimizer param group is logged.  The logger is
        closed when the loop finishes.
        """
        num_iters = self.ed_iter - self.st_iter + 1
        for it in tqdm(range(self.st_iter, self.ed_iter + 1), total=num_iters,
                       leave=False, dynamic_ncols=True):
            self.model.train()
            self.train_iter(it)

            if it % self.iter_save == 0:
                self.model.eval()
                self.eval(it)

                train_avg = self.train_meter.average()
                eval_avg = self.eval_meter.average()
                # Handed to do_schedule() below as `metrics`.
                self.metric = eval_avg.silog

                # Each tag overlays a train curve and a test curve; every key is
                # named after the metric it actually plots.
                self.logger.add_scalars('TrainVal/rmse',
                                        {'train_rmse': train_avg.rmse, 'test_rmse': eval_avg.rmse}, it)
                self.logger.add_scalars('TrainVal/rel',
                                        {'train_rel': train_avg.absrel, 'test_rel': eval_avg.absrel}, it)
                self.logger.add_scalars('TrainVal/lg10',
                                        {'train_lg10': train_avg.lg10, 'test_lg10': eval_avg.lg10}, it)
                self.logger.add_scalars('TrainVal/Delta1',
                                        {'train_d1': train_avg.delta1, 'test_d1': eval_avg.delta1}, it)
                self.logger.add_scalars('TrainVal/Delta2',
                                        {'train_d2': train_avg.delta2, 'test_d2': eval_avg.delta2}, it)
                self.logger.add_scalars('TrainVal/Delta3',
                                        {'train_d3': train_avg.delta3, 'test_d3': eval_avg.delta3}, it)
                self.train_meter.reset()

                # Remember the best absrel (lower is better) and persist it.
                is_best = eval_avg.absrel < self.best_result.absrel
                if is_best:
                    self.best_result = eval_avg
                    with open(self.best_txt, 'w') as txtfile:
                        txtfile.write(
                            "Iter={}, rmse={:.3f}, rel={:.3f}, log10={:.3f}, d1={:.3f}, d2={:.3f}, d3={:.3f}, "
                            "t_gpu={:.4f}".format(it, eval_avg.rmse, eval_avg.absrel, eval_avg.lg10,
                                                  eval_avg.delta1, eval_avg.delta2, eval_avg.delta3, eval_avg.gpu_time))

                # Save a checkpoint at every evaluation iteration.
                utils.save_checkpoint({
                    'args': self.opt,
                    'epoch': it,
                    'state_dict': self.model.state_dict(),
                    'best_result': self.best_result,
                    'optimizer': self.optimizer,
                }, is_best, it, self.output_directory)

            # Update learning rate
            do_schedule(self.opt, self.scheduler, it=it, len=self.iter_save, metrics=self.metric)

            # Record the current learning rate of every param group.
            for i, param_group in enumerate(self.optimizer.param_groups):
                self.logger.add_scalar('Lr/lr_' + str(i), float(param_group['lr']), it)

        self.logger.close()
Example #2
0
  def train_comparative(self, student, export_only=False):
    """
      Trains the student using a comparative loss function (Mean Squared Error)
      based on the output of Teacher.
      Args:
        student: Keras model of the student.
        export_only: If True, only restore "comparative_checkpoint" into the
          student (when one is found) and return without training.
    """
    train_args = self.train_args["comparative"]
    total_steps = train_args["num_steps"]
    decay_rate = train_args["decay_rate"]
    # Sorted private copy. Entries are popped as their decay fires, and popping
    # from train_args directly would silently mutate the shared configuration.
    decay_steps = sorted(train_args["decay_steps"])
    optimizer = tf.optimizers.Adam()
    checkpoint = tf.train.Checkpoint(
        student_generator=student,
        student_optimizer=optimizer)
    status = utils.load_checkpoint(
        checkpoint,
        "comparative_checkpoint",
        basepath=self.model_dir,
        use_student_settings=True)
    if export_only:
      return
    metric_fn = tf.keras.metrics.Mean()
    student_psnr = tf.keras.metrics.Mean()
    teacher_psnr = tf.keras.metrics.Mean()

    def step_fn(image_lr, image_hr):
      """
        Function to be replicated among the worker nodes
        Args:
          image_lr: Distributed Batch of Low Resolution Images
          image_hr: Distributed Batch of High Resolution Images
        Returns:
          The optimizer's iteration counter (float32), read after the update.
      """
      with tf.GradientTape() as tape:
        teacher_fake = self.teacher_generator.unsigned_call(image_lr)
        student_fake = student.unsigned_call(image_lr)
        student_fake = tf.clip_by_value(student_fake, 0, 255)
        teacher_fake = tf.clip_by_value(teacher_fake, 0, 255)
        image_hr = tf.clip_by_value(image_hr, 0, 255)
        # Every image is clipped to [0, 255], so the PSNR peak value is 255.
        student_psnr(tf.reduce_mean(tf.image.psnr(student_fake, image_hr, max_val=255.0)))
        teacher_psnr(tf.reduce_mean(tf.image.psnr(teacher_fake, image_hr, max_val=255.0)))
        loss = utils.pixelwise_mse(teacher_fake, student_fake)
        loss = tf.reduce_mean(loss) * (1.0 / self.batch_size)
        metric_fn(loss)
      # No set() round-trip: it is not needed to pair gradients with
      # variables, and tf.Variable objects may be unhashable in TF2.
      student_vars = student.trainable_variables
      gradient = tape.gradient(loss, student_vars)
      train_op = optimizer.apply_gradients(
          zip(gradient, student_vars))
      with tf.control_dependencies([train_op]):
        return tf.cast(optimizer.iterations, tf.float32)

    @tf.function
    def train_step(image_lr, image_hr):
      """
        In Graph Function to assign trainer function to
        replicate among worker nodes.
        Args:
          image_lr: Distributed batch of Low Resolution Images
          image_hr: Distributed batch of High Resolution Images
      """
      distributed_metric = self.strategy.experimental_run_v2(
          step_fn, args=(image_lr, image_hr))
      mean_metric = self.strategy.reduce(
          tf.distribute.ReduceOp.MEAN, distributed_metric, axis=None)
      return mean_metric

    logging.info("Starting comparative loss training")
    while True:
      image_lr, image_hr = next(self.dataset)
      step = train_step(image_lr, image_hr)
      # Apply every decay whose step has been reached.
      # NOTE(review): after a resume, the restored optimizer learning rate may
      # already contain earlier decays, which would then be applied again
      # here. Confirm whether the checkpoint tracks learning_rate.
      while decay_steps and step >= decay_steps[0]:
        decay_steps.pop(0)
        logging.debug("Reducing Learning Rate by: %f", decay_rate)
        optimizer.learning_rate.assign(optimizer.learning_rate * decay_rate)
      if status:
        # Deferred restoration only completes once the variables exist,
        # i.e. after the first train_step.
        status.assert_consumed()
        logging.info("Checkpoint loaded successfully")
        status = None
      # Writing Summary
      with self.summary_writer.as_default():
        tf.summary.scalar("student_loss", metric_fn.result(), step=optimizer.iterations)
        tf.summary.scalar("mean_psnr", student_psnr.result(), step=optimizer.iterations)
      if self.summary_writer_2:
        with self.summary_writer_2.as_default():
          tf.summary.scalar("mean_psnr", teacher_psnr.result(), step=optimizer.iterations)

      if not step % train_args["print_step"]:
        logging.info("[COMPARATIVE LOSS] Step: %s\tLoss: %s",
                     step, metric_fn.result())
      # Saving Checkpoint
      if not step % train_args["checkpoint_step"]:
        utils.save_checkpoint(
            checkpoint,
            "comparative_checkpoint",
            basepath=self.model_dir,
            use_student_settings=True)
      # Stop only after the last step has been summarized and checkpointed.
      if step >= total_steps:
        return
Example #3
0
  def train_adversarial(self, student, export_only=False):
    """
      Train the student adversarially using a joint loss between teacher discriminator
      and mean squared error between the output of the student-teacher generator pair.
      Args:
        student: Keras model of the student to train.
        export_only: If True, restore "adversarial_checkpoint" into the student
          and return without training.
      Raises:
        ValueError: if `export_only` is set and no adversarial checkpoint can be
          found or loaded.
    """
    train_args = self.train_args["adversarial"]
    total_steps = train_args["num_steps"]
    decay_steps = train_args["decay_steps"]
    if isinstance(decay_steps, (list, tuple)):
      # Sorted private copy. Entries are popped as their decay fires, which
      # must not mutate the shared training configuration.
      decay_steps = sorted(decay_steps)
    decay_rate = train_args["decay_rate"]

    lambda_ = train_args["lambda"]
    alpha = train_args["alpha"]

    generator_metric = tf.keras.metrics.Mean()
    discriminator_metric = tf.keras.metrics.Mean()
    generator_optimizer = tf.optimizers.Adam(learning_rate=train_args["initial_lr"])
    discriminator_optimizer = tf.optimizers.Adam(learning_rate=train_args["initial_lr"])
    # Attribute names are the checkpoint keys; keep them stable.
    checkpoint = tf.train.Checkpoint(
        student_generator=student,
        student_optimizer=generator_optimizer,
        teacher_optimizer=discriminator_optimizer,
        teacher_generator=self.teacher_generator,
        teacher_discriminator=self.teacher_discriminator)

    status = None
    if not utils.checkpoint_exists(
        names="adversarial_checkpoint",
        basepath=self.model_dir,
        use_student_settings=True):
      if export_only:
        raise ValueError("Adversarial checkpoints not found")
    else:
      status = utils.load_checkpoint(
          checkpoint,
          "adversarial_checkpoint",
          basepath=self.model_dir,
          use_student_settings=True)
    if export_only:
      if not status:
        raise ValueError("No Checkpoint Loaded!")
      return
    ra_generator = utils.RelativisticAverageLoss(
        self.teacher_discriminator, type_="G")
    ra_discriminator = utils.RelativisticAverageLoss(
        self.teacher_discriminator, type_="D")
    perceptual_loss = utils.PerceptualLoss(
        weights="imagenet",
        input_shape=self.hr_size,
        loss_type="L2")
    student_psnr = tf.keras.metrics.Mean()
    teacher_psnr = tf.keras.metrics.Mean()

    def step_fn(image_lr, image_hr):
      """
        Function to be replicated among the worker nodes
        Args:
          image_lr: Distributed Batch of Low Resolution Images
          image_hr: Distributed Batch of High Resolution Images
        Returns:
          The discriminator optimizer's iteration counter (float32), read after
          both updates are applied.
      """
      with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        teacher_fake = self.teacher_generator.unsigned_call(image_lr)
        logging.debug("Fetched Fake: Teacher")
        teacher_fake = tf.clip_by_value(teacher_fake, 0, 255)
        student_fake = student.unsigned_call(image_lr)
        logging.debug("Fetched Fake: Student")
        student_fake = tf.clip_by_value(student_fake, 0, 255)
        image_hr = tf.clip_by_value(image_hr, 0, 255)
        psnr = tf.image.psnr(student_fake, image_hr, max_val=255.0)
        student_psnr(tf.reduce_mean(psnr))
        psnr = tf.image.psnr(teacher_fake, image_hr, max_val=255.0)
        teacher_psnr(tf.reduce_mean(psnr))
        # MSE is taken on the clipped [0, 255] images, before preprocessing.
        mse_loss = utils.pixelwise_mse(teacher_fake, student_fake)

        image_lr = utils.preprocess_input(image_lr)
        image_hr = utils.preprocess_input(image_hr)
        student_fake = utils.preprocess_input(student_fake)
        teacher_fake = utils.preprocess_input(teacher_fake)

        student_ra_loss = ra_generator(teacher_fake, student_fake)
        logging.debug("Relativistic Average Loss: Student")
        discriminator_loss = ra_discriminator(teacher_fake, student_fake)
        discriminator_metric(discriminator_loss)
        discriminator_loss = tf.reduce_mean(
            discriminator_loss) * (1.0 / self.batch_size)
        logging.debug("Relativistic Average Loss: Teacher")
        percep_loss = perceptual_loss(image_hr, student_fake)
        generator_loss = lambda_ * percep_loss + alpha * student_ra_loss + (1 - alpha) * mse_loss
        generator_metric(generator_loss)
        logging.debug("Calculated Joint Loss for Generator")
        generator_loss = tf.reduce_mean(
            generator_loss) * (1.0 / self.batch_size)
      generator_gradient = gen_tape.gradient(
          generator_loss, student.trainable_variables)
      logging.debug("calculated gradient: generator")
      discriminator_gradient = disc_tape.gradient(
          discriminator_loss, self.teacher_discriminator.trainable_variables)
      logging.debug("calculated gradient: discriminator")
      generator_op = generator_optimizer.apply_gradients(
          zip(generator_gradient, student.trainable_variables))
      logging.debug("applied generator gradients")
      discriminator_op = discriminator_optimizer.apply_gradients(
          zip(discriminator_gradient, self.teacher_discriminator.trainable_variables))
      logging.debug("applied discriminator gradients")

      with tf.control_dependencies(
              [generator_op, discriminator_op]):
        return tf.cast(discriminator_optimizer.iterations, tf.float32)

    @tf.function
    def train_step(image_lr, image_hr):
      """
        In Graph Function to assign trainer function to
        replicate among worker nodes.
        Args:
          image_lr: Distributed batch of Low Resolution Images
          image_hr: Distributed batch of High Resolution Images
      """
      distributed_metric = self.strategy.experimental_run_v2(
          step_fn,
          args=(image_lr, image_hr))
      mean_metric = self.strategy.reduce(
          tf.distribute.ReduceOp.MEAN,
          distributed_metric, axis=None)
      return mean_metric

    def decay_learning_rates():
      """Multiply both optimizers' learning rates by `decay_rate`."""
      logging.debug("Decaying Learning Rate by: %s", decay_rate)
      for opt in (generator_optimizer, discriminator_optimizer):
        opt.learning_rate.assign(opt.learning_rate * decay_rate)

    logging.info("Starting Adversarial Training")

    while True:
      image_lr, image_hr = next(self.dataset)
      step = train_step(image_lr, image_hr)
      if status:
        # Deferred restoration only completes once the variables exist,
        # i.e. after the first train_step.
        status.assert_consumed()
        logging.info("Loaded Checkpoint successfully!")
        status = None
      # NOTE(review): after a resume, the restored learning rates may already
      # include earlier decays, which the list form re-applies here. Confirm.
      if isinstance(decay_steps, list):
        # Explicit schedule: apply every decay whose step has been reached.
        while decay_steps and decay_steps[0] <= step:
          decay_steps.pop(0)
          decay_learning_rates()
      elif not step % decay_steps:
        # Scalar schedule: decay every `decay_steps` steps.
        decay_learning_rates()
      # Setting Up Logging
      with self.summary_writer.as_default():
        tf.summary.scalar(
            "student_loss",
            generator_metric.result(),
            step=discriminator_optimizer.iterations)
        tf.summary.scalar(
            "teacher_discriminator_loss",
            discriminator_metric.result(),
            step=discriminator_optimizer.iterations)
        tf.summary.scalar(
            "mean_psnr",
            student_psnr.result(),
            step=discriminator_optimizer.iterations)
      if self.summary_writer_2:
        with self.summary_writer_2.as_default():
          tf.summary.scalar(
              "mean_psnr",
              teacher_psnr.result(),
              step=discriminator_optimizer.iterations)
      if not step % train_args["print_step"]:
        logging.info(
            "[ADVERSARIAL] Step: %s\tStudent Loss: %s\t"
            "Discriminator Loss: %s",
            step, generator_metric.result(),
            discriminator_metric.result())
      # Setting Up Checkpoint
      if not step % train_args["checkpoint_step"]:
        utils.save_checkpoint(
            checkpoint,
            "adversarial_checkpoint",
            basepath=self.model_dir,
            use_student_settings=True)
      if step >= total_steps:
        return
def main(args):
    """Set up, optionally resume, then train and periodically validate a video model.

    The config is read from ``args.config``; checkpoints and TensorBoard logs go to
    ``./ckpt/<config file name without '.json'>``.  Distributed training
    (``DistributedDataParallel`` + NCCL) is enabled when the ``WORLD_SIZE``
    environment variable is > 1; otherwise ``DataParallel`` is used over
    ``config['network']['devices']``.  If ``args.resume`` is set, model,
    optimizer and scheduler state are restored from that checkpoint.

    Args:
        args: parsed command-line arguments (config, local_rank, resume, fix_res,
            sync_bn, prec_bn, valid_freq, ...).  ``start_epoch``, ``distributed``
            and ``world_size`` are written onto it.

    Raises:
        ValueError: if ``args.config`` does not exist.
    """
    ##############################################################################
    # Setup parameters
    best_acc1 = 0.0
    args.start_epoch = 0
    if os.path.exists(args.config):
        config = load_config(args.config)
    else:
        raise ValueError("Config file does not exist.")
    if args.local_rank == 0:
        print("Current configurations:")
        pprint(config)

    # Prep the output folder. makedirs also creates ./ckpt when it is missing
    # and tolerates the folder already existing.
    config_filename = os.path.basename(args.config).replace('.json', '')
    ckpt_folder = os.path.join('./ckpt', config_filename)
    if args.local_rank == 0:
        os.makedirs(ckpt_folder, exist_ok=True)
    # tensorboard writer (only rank 0 writes logs)
    global writer
    if args.local_rank == 0:
        writer = SummaryWriter(os.path.join(ckpt_folder, 'logs'))

    # use spawn for mp, this will fix a deadlock by OpenCV
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')

    # for distributed training
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = (int(os.environ['WORLD_SIZE']) > 1)
    args.world_size = 1
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        print(
            "Distributed training (local rank {:d} / world size {:d})".format(
                args.local_rank, args.world_size))

    # fix the random seeds (the best we can); each rank gets its own seed
    fixed_random_seed = 2019 + int(args.distributed) * args.local_rank
    torch.manual_seed(fixed_random_seed)
    np.random.seed(fixed_random_seed)
    random.seed(fixed_random_seed)

    # skip weight loading if resume from a checkpoint
    if args.resume:
        config['network']['pretrained'] = None

    # re-scale learning rate based on world size
    if args.distributed:
        config['optimizer']["learning_rate"] *= args.world_size
    else:
        # also need to re-scale the worker number if using data parallel
        config['optimizer']["learning_rate"] *= len(
            config['network']['devices'])
        config['input']['num_workers'] *= len(config['network']['devices'])

    ##############################################################################
    # Create datasets
    # set up transforms and dataset
    train_transforms, val_transforms, _ = \
      create_video_transforms_joint(config['input'])
    train_dataset, val_dataset = create_video_dataset_joint(
        config['dataset'], train_transforms, val_transforms)
    is_train, is_test = (train_dataset is not None), (val_dataset is not None)

    # print the data augs
    if args.local_rank == 0:
        print("Training time data augmentations:")
        pprint(train_transforms)
        print("Testing time data augmentations:")
        pprint(val_transforms)

    if is_train:
        # only instantiate the dataset if necessary
        train_dataset = train_dataset()
        train_dataset.load()

    if is_test:
        # only instantiate the dataset if necessary
        val_dataset = val_dataset()
        val_dataset.load()

    # Reset loss params. Class counts come from the training set, so this is
    # only possible when a training set exists.
    if is_train and config['network']['balanced_beta'] > 0:
        num_samples_per_cls = train_dataset.get_num_samples_per_cls()
        config['network']['cls_weights'] = get_cls_weights(
            num_samples_per_cls, config['network']['balanced_beta'])
        if args.local_rank == 0:
            print("Using class balanced loss with beta = {:0.4f}".format(
                config['network']['balanced_beta']))

    ##############################################################################
    # Create model (w. loss) & optimizer
    # create model -> GPU 0
    model = ModelBuilder(config['network'])
    if args.sync_bn and args.distributed:
        # discrepancy in docs (this now works as default for pytorch 1.2)
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        # this will force the model to re-freeze the params (bug fixed in 1.2)
        model.train()
    model = model.cuda()

    # create optimizer (before wrapping, so it sees the raw module params)
    optimizer = create_optim(model, config['optimizer'])

    # wrap the model for (distributed) data parallelism
    if args.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)
    else:
        model = nn.DataParallel(model, device_ids=config['network']['devices'])

    ##############################################################################
    # Create data loaders / scheduler
    if is_train:
        train_sampler = None
        if args.distributed:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=config['input']['batch_size'],
            num_workers=config['input']['num_workers'],
            collate_fn=fast_clip_collate_joint,
            shuffle=(train_sampler is None),
            pin_memory=True,
            sampler=train_sampler,
            drop_last=True)

    if is_test:
        val_sampler = None
        # each validation sample expands to several clips, so shrink the batch
        val_batch_size = max(
            1, config['input']['batch_size'] // val_dataset.get_num_clips())
        if args.distributed:
            val_sampler = torch.utils.data.distributed.DistributedSampler(
                val_dataset)
        # validation here is not going to be accurate any way ...
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=val_batch_size,
            num_workers=config['input']['num_workers'],
            collate_fn=fast_clip_collate_joint,
            shuffle=False,
            pin_memory=True,
            sampler=val_sampler,
            drop_last=True)

    # set up learning rate scheduler
    if is_train:
        num_iters_per_epoch = len(train_loader)
        scheduler = create_scheduler(optimizer,
                                     config['optimizer']['schedule'],
                                     config['optimizer']['epochs'],
                                     num_iters_per_epoch)

    ##############################################################################
    # Resume from model / Misc
    if args.resume:
        if os.path.isfile(args.resume):
            if args.local_rank == 0:
                print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume,
                                    map_location=lambda storage, loc: storage.
                                    cuda(args.local_rank))
            if not args.fix_res:
                args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            model.load_state_dict(checkpoint['state_dict'])
            # only load the optimizer if necessary
            if is_train and (not args.fix_res):
                optimizer.load_state_dict(checkpoint['optimizer'])
                scheduler.load_state_dict(checkpoint['scheduler'])
            if args.local_rank == 0:
                print("=> loaded checkpoint '{}' (epoch {}, acc1 {})".format(
                    args.resume, checkpoint['epoch'], best_acc1))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
            if args.local_rank == 0:
                writer.close()
            return

    # training: enable cudnn benchmark
    cudnn.enabled = True
    cudnn.benchmark = True
    # model architecture
    model_arch = "{:s}-{:s}".format(config['network']['backbone'],
                                    config['network']['decoder'])

    ##############################################################################
    # Training / Validation
    if is_train:
        if args.local_rank == 0:
            # save the current config
            with open(os.path.join(ckpt_folder, 'config.text'), 'w') as fid:
                pprint(config, stream=fid)
            print("Training model {:s} ...".format(model_arch))
            pprint(model)

        for epoch in range(args.start_epoch, config['optimizer']['epochs']):
            if args.distributed:
                # reshuffle the distributed shards differently every epoch
                train_sampler.set_epoch(epoch)
            # train for one epoch
            train(train_loader, model, optimizer, scheduler, epoch, args,
                  config)

            # Validate every `valid_freq` epochs (this includes epoch 0, to make
            # sure training is on track) and on every epoch in the last 40%.
            if (epoch % args.valid_freq == 0) \
                or (epoch > 0.6 * config['optimizer']['epochs']):
                # use prec bn to aggregate stats before validation
                if args.prec_bn:
                    prec_bn(train_loader, model, epoch, args, config)
                acc1, acc5 = validate(val_loader, model, epoch, args, config)

                if args.local_rank == 0:
                    # remember best acc@1 and save checkpoint
                    is_best = acc1 > best_acc1
                    best_acc1 = max(acc1, best_acc1)
                    save_checkpoint(
                        {
                            'epoch': epoch + 1,
                            'model_arch': model_arch,
                            'state_dict': model.state_dict(),
                            # store the running best (read back on resume)
                            'best_acc1': best_acc1,
                            'scheduler': scheduler.state_dict(),
                            'optimizer': optimizer.state_dict(),
                        },
                        is_best,
                        file_folder=ckpt_folder)

            # sync all processes manually
            if args.distributed:
                sync_processes()
            else:
                torch.cuda.empty_cache()

    if args.local_rank == 0:
        writer.close()
        print("All done!")
Example #5
0
def main():
    """Train or test a segmentation model (mean IoU / mean accuracy metrics).

    Behaviour is driven by the module-level ``args``:

    * ``args.mode == 'test'``: load the pickled model from ``args.resume`` and run
      ``validate`` once.
    * ``args.mode == 'train'``: train for iterations ``start_iter..args.max_iter``
      with gradient accumulation over ``args.iter_size`` mini-batches.  Results
      are logged to TensorBoard; every ``args.iter_save`` iterations the model is
      validated and checkpointed.
    * Any other mode prints an error and exits with status -1.

    Updates the module-level ``best_result`` and ``output_directory``.
    """
    global args, best_result, output_directory

    # set random seed
    torch.manual_seed(args.manual_seed)
    torch.cuda.manual_seed(args.manual_seed)
    np.random.seed(args.manual_seed)

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        args.batch_size = args.batch_size * torch.cuda.device_count()
    else:
        print("Let's use GPU ", torch.cuda.current_device())

    train_loader, val_loader = create_loader(args)

    if args.mode == 'test':
        if not args.resume:
            # Without a checkpoint there is no model (or epoch) to evaluate.
            print("no trained model to test.")
            return
        assert os.path.isfile(args.resume), \
            "=> no checkpoint found at '{}'".format(args.resume)
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)

        epoch = checkpoint['epoch']
        best_result = checkpoint['best_result']

        # The checkpoint stores the whole model object, not a state_dict
        # (the original author's note: "solve 'out of memory'").
        model = checkpoint['model']

        print("=> loaded checkpoint (epoch {})".format(
            checkpoint['epoch']))

        # free the checkpoint before evaluation
        del checkpoint
        torch.cuda.empty_cache()

        result, img_merge = validate(args,
                                     val_loader,
                                     model,
                                     epoch,
                                     logger=None)

        print(
            'Test Result: mean iou={result.mean_iou:.3f}, mean acc={result.mean_acc:.3f}.'
            .format(result=result))
    elif args.mode == 'train':
        if args.resume:
            assert os.path.isfile(args.resume), \
                "=> no checkpoint found at '{}'".format(args.resume)
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)

            start_iter = checkpoint['epoch'] + 1
            best_result = checkpoint['best_result']
            optimizer = checkpoint['optimizer']

            # The checkpoint stores the whole model object, not a state_dict.
            model = checkpoint['model']

            print("=> loaded checkpoint (epoch {})".format(
                checkpoint['epoch']))

            # free the checkpoint before training
            del checkpoint
            torch.cuda.empty_cache()
        else:
            print("=> creating Model")
            model = get_models(args)
            print("=> model created.")
            start_iter = 1

            # different modules have different learning rate
            train_params = [{
                'params': model.get_1x_lr_params(),
                'lr': args.lr
            }, {
                'params': model.get_10x_lr_params(),
                'lr': args.lr * 10
            }]

            print(train_params)

            optimizer = torch.optim.SGD(train_params,
                                        lr=args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)

            # You can use DataParallel() whether you use Multi-GPUs or not
            model = nn.DataParallel(model).cuda()

        scheduler = PolynomialLR(optimizer=optimizer,
                                 step_size=args.lr_decay,
                                 iter_max=args.max_iter,
                                 power=args.power)

        # loss function
        criterion = criteria._CrossEntropyLoss2d(size_average=True,
                                                 batch_average=True)

        # create directory path
        output_directory = utils.get_output_directory(args)
        if not os.path.exists(output_directory):
            os.makedirs(output_directory)
        best_txt = os.path.join(output_directory, 'best.txt')
        config_txt = os.path.join(output_directory, 'config.txt')

        # write training parameters to config file (first run only)
        if not os.path.exists(config_txt):
            with open(config_txt, 'w') as txtfile:
                txtfile.write(''.join(
                    str(k) + ':' + str(v) + ',\t\n' for k, v in vars(args).items()))

        # create log
        log_path = os.path.join(
            output_directory, 'logs',
            datetime.now().strftime('%b%d_%H-%M-%S') + '_' +
            socket.gethostname())
        if os.path.isdir(log_path):
            shutil.rmtree(log_path)
        os.makedirs(log_path)
        logger = SummaryWriter(log_path)

        # train
        model.train()
        if args.freeze:
            model.module.freeze_backbone_bn()
        output_directory = utils.get_output_directory(args, check=True)

        average_meter = AverageMeter()

        # Endless stream of batches: restarted below whenever an epoch ends.
        loader_iter = iter(train_loader)
        for it in tqdm(range(start_iter, args.max_iter + 1),
                       total=args.max_iter - start_iter + 1,
                       leave=False,
                       dynamic_ncols=True):
            # Clear gradients (ready to accumulate)
            optimizer.zero_grad()

            loss = 0

            data_time = 0
            gpu_time = 0

            for _ in range(args.iter_size):
                end = time.time()
                try:
                    samples = next(loader_iter)
                except StopIteration:
                    # Epoch exhausted: start a new pass over the data.
                    loader_iter = iter(train_loader)
                    samples = next(loader_iter)

                images = samples['image'].cuda()
                target = samples['label'].cuda()

                torch.cuda.synchronize()
                data_time_ = time.time()
                data_time += data_time_ - end

                # NOTE(review): anomaly detection slows every step considerably
                # and looks like a debugging leftover; kept to preserve behavior.
                with torch.autograd.detect_anomaly():
                    preds = model(images)

                    iter_loss = 0
                    if args.msc:
                        for pred in preds:
                            # Resize labels for {100%, 75%, 50%, Max} logits
                            target_ = utils.resize_labels(
                                target,
                                shape=(pred.size()[-2], pred.size()[-1]))
                            iter_loss += criterion(pred, target_)
                    else:
                        pred = preds
                        target_ = utils.resize_labels(target,
                                                      shape=(pred.size()[-2],
                                                             pred.size()[-1]))
                        iter_loss += criterion(pred, target_)

                    # Backpropagate (just compute gradients wrt the loss);
                    # dividing by iter_size averages the accumulated gradient.
                    iter_loss /= args.iter_size
                    iter_loss.backward()

                    loss += float(iter_loss)

                gpu_time += time.time() - data_time_

            torch.cuda.synchronize()

            # Update weights with accumulated gradients
            optimizer.step()

            # Update learning rate
            scheduler.step(epoch=it)

            # measure accuracy and record loss
            # (with args.msc, `pred` is the last element of `preds`)
            result = Result()
            pred = F.softmax(pred, dim=1)

            result.evaluate(pred.data.cpu().numpy(),
                            target.data.cpu().numpy(),
                            n_class=21)
            average_meter.update(result, gpu_time, data_time, images.size(0))

            if it % args.print_freq == 0:
                print('=> output: {}'.format(output_directory))
                print('Train Iter: [{0}/{1}]\t'
                      't_Data={data_time:.3f}({average.data_time:.3f}) '
                      't_GPU={gpu_time:.3f}({average.gpu_time:.3f})\n\t'
                      'Loss={Loss:.5f} '
                      'MeanAcc={result.mean_acc:.3f}({average.mean_acc:.3f}) '
                      'MIOU={result.mean_iou:.3f}({average.mean_iou:.3f}) '.
                      format(it,
                             args.max_iter,
                             data_time=data_time,
                             gpu_time=gpu_time,
                             Loss=loss,
                             result=result,
                             average=average_meter.average()))
                logger.add_scalar('Train/Loss', loss, it)
                logger.add_scalar('Train/mean_acc', result.mean_acc, it)
                logger.add_scalar('Train/mean_iou', result.mean_iou, it)

                for i, param_group in enumerate(optimizer.param_groups):
                    logger.add_scalar('Lr/lr_' + str(i),
                                      float(param_group['lr']), it)

            if it % args.iter_save == 0:
                # Keep the *validation* result; it drives best-model selection.
                result, img_merge = validate(args,
                                             val_loader,
                                             model,
                                             epoch=it,
                                             logger=logger)

                # Remember the best mean IoU (higher is better) and save it.
                # NOTE(review): on a fresh run, best_result must start at the
                # lowest possible mean_iou for this to fire; confirm its init.
                is_best = result.mean_iou > best_result.mean_iou
                if is_best:
                    best_result = result
                    with open(best_txt, 'w') as txtfile:
                        txtfile.write(
                            "Iter={}, mean_iou={:.3f}, mean_acc={:.3f}, "
                            "t_gpu={:.4f}".format(it, result.mean_iou,
                                                  result.mean_acc,
                                                  result.gpu_time))
                    if img_merge is not None:
                        img_filename = os.path.join(output_directory,
                                                    'comparison_best.png')
                        utils.save_image(img_merge, img_filename)

                # save a checkpoint at every validation iteration
                utils.save_checkpoint(
                    {
                        'args': args,
                        'epoch': it,
                        'model': model,
                        'best_result': best_result,
                        'optimizer': optimizer,
                    }, is_best, it, output_directory)

                # change back to train mode (validate may switch to eval)
                model.train()
                if args.freeze:
                    model.module.freeze_backbone_bn()

        logger.close()
    else:
        print('no mode named as ', args.mode)
        exit(-1)