Exemplo n.º 1
0
    def configure_optimizers(self):
        """Create the optimizer and learning-rate scheduler used for training.

        The optimizer is built by ``create_optimizer(self.cfg,
        self.model.parameters())``. The scheduler is chosen by
        ``self.cfg.lr_scheduler``:

        * ``'step'`` -- ``StepLR`` with ``cfg.step_size`` and ``cfg.lr_factor``.
        * ``'cosin'`` -- ``CosineAnnealingLR`` with fixed ``T_max=200`` and
          ``eta_min=1e-6``.
        * ``'cosin_epoch'`` -- ``CosineAnnealingLR`` with ``cfg.tmax`` and
          ``cfg.eta_min``.
        * ``'onecycle'`` -- ``OneCycleLR`` whose peak lr is each param group's
          current lr, run for ``self.hparams.epochs`` epochs of
          ``len(self.train_dataloader())`` steps and stepped every batch.

        Returns:
            tuple[list, list]: ``([optimizer], [scheduler])``. ``scheduler``
            is either an LR scheduler object or, for ``'onecycle'``, the dict
            ``{"scheduler": ..., "interval": "step"}``. This is the two-list
            form PyTorch Lightning accepts (assuming this is a
            ``LightningModule`` hook). A bare ``(optimizer, scheduler)``
            tuple would instead be read as a sequence of optimizers.

        Raises:
            ValueError: if ``self.cfg.lr_scheduler`` is none of the names above.
        """
        optimizer = create_optimizer(self.cfg, self.model.parameters())

        if self.cfg.lr_scheduler == 'step':
            scheduler = optim.lr_scheduler.StepLR(optimizer,
                                                  step_size=self.cfg.step_size,
                                                  gamma=self.cfg.lr_factor)
        elif self.cfg.lr_scheduler == 'cosin':
            scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                             T_max=200,
                                                             eta_min=1e-6)
        elif self.cfg.lr_scheduler == 'cosin_epoch':
            scheduler = optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=self.cfg.tmax, eta_min=self.cfg.eta_min)
        elif self.cfg.lr_scheduler == 'onecycle':
            # One peak lr per param group, taken from the optimizer as built.
            max_lr = [g["lr"] for g in optimizer.param_groups]
            scheduler = optim.lr_scheduler.OneCycleLR(
                optimizer,
                max_lr=max_lr,
                epochs=self.hparams.epochs,
                steps_per_epoch=len(self.train_dataloader()))
            # OneCycleLR is defined per optimizer step, so step it every batch
            # instead of the default once-per-epoch interval.
            scheduler = {"scheduler": scheduler, "interval": "step"}
        else:
            raise ValueError(
                'Does not support {} learning rate scheduler'.format(
                    self.cfg.lr_scheduler))
        # Two lists: optimizers and schedulers. Returning the bare pair would
        # make the scheduler be treated as a second optimizer.
        return [optimizer], [scheduler]
def run(args):
    """Train ``args.model`` on a multi-view image dataset and checkpoint it.

    Sets up (optionally distributed) execution and builds the model and
    optimizer. It then applies the optional APEX AMP, EMA, SyncBatchNorm and
    DDP wrappers and creates the LR scheduler. Finally it alternates
    ``train_epoch`` and ``validate`` for every scheduled epoch.

    Only the process with ``args.local_rank == 0`` creates a timestamped
    output directory. That process alone writes ``summary.csv``, the metric
    plots and the checkpoints there.

    Args:
        args: parsed command-line namespace. It is mutated in place:
            ``prefetcher``, ``distributed``, ``world_size`` and ``rank`` are
            always set. ``num_gpu``, ``device`` and ``amp`` may be overwritten.

    NOTE(review): ``train_file``, ``test_file``, ``class_file``,
    ``transform_train`` and ``transform_eval`` are not defined in this
    function. They presumably exist at module level -- confirm.
    """
    setup_default_logging()
    args.prefetcher = not args.no_prefetcher

    # ---- Distributed setup (one GPU per process when WORLD_SIZE > 1) -------
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
        if args.distributed and args.num_gpu > 1:
            logging.warning(
                'Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.'
            )
            args.num_gpu = 1

    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    assert args.rank >= 0

    if args.distributed:
        logging.info(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (args.rank, args.world_size))
    else:
        logging.info('Training with a single process on %d GPUs.' %
                     args.num_gpu)

    # Offset the seed by the global rank so processes are seeded differently.
    torch.manual_seed(args.seed + args.rank)

    # ---- Model -------------------------------------------------------------
    model = create_model(args.model,
                         pretrained=args.pretrained,
                         num_classes=args.num_classes,
                         drop_rate=args.drop,
                         global_pool=args.gp,
                         bn_tf=args.bn_tf,
                         bn_momentum=args.bn_momentum,
                         bn_eps=args.bn_eps,
                         checkpoint_path=args.initial_checkpoint)

    if args.local_rank == 0:
        num_params = sum(p.numel() for p in model.parameters())
        logging.info('Model %s created, param count: %d', args.model,
                     num_params)

    data_config = resolve_data_config(vars(args),
                                      model=model,
                                      verbose=args.local_rank == 0)

    # Optionally resume from a checkpoint.
    optimizer_state = None
    resume_epoch = None
    if args.resume:
        optimizer_state, resume_epoch = resume_checkpoint(model, args.resume)

    if args.num_gpu > 1:
        if args.amp:
            logging.warning(
                'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.'
            )
            args.amp = False
        model = nn.DataParallel(model,
                                device_ids=list(range(args.num_gpu))).cuda()
    else:
        model.cuda()

    # ---- Optimizer / AMP / EMA ---------------------------------------------
    optimizer = create_optimizer(args, model)
    if optimizer_state is not None:
        optimizer.load_state_dict(optimizer_state)

    use_amp = False
    if has_apex and args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        use_amp = True
    if args.local_rank == 0:
        logging.info('NVIDIA APEX {}. AMP {}.'.format(
            'installed' if has_apex else 'not installed',
            'on' if use_amp else 'off'))

    model_ema = None
    if args.model_ema:
        # Create the EMA model after .cuda() so it copies the moved weights.
        model_ema = ModelEma(model,
                             decay=args.model_ema_decay,
                             device='cpu' if args.model_ema_force_cpu else '',
                             resume=args.resume)

    if args.distributed:
        if args.sync_bn:
            try:
                if has_apex:
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                        model)
                if args.local_rank == 0:
                    logging.info(
                        'Converted model to use Synchronized BatchNorm.')
            except Exception:
                # Deliberately best-effort, but keep the traceback so the real
                # cause is visible instead of being silently discarded.
                logging.exception(
                    'Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1'
                )
        if has_apex:
            model = DDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                logging.info(
                    "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP."
                )
            # Can use a device str in Torch >= 1.1.
            model = DDP(model, device_ids=[args.local_rank])
        # NOTE: EMA model does not need to be wrapped by DDP.

    # ---- LR schedule -------------------------------------------------------
    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # A specified start_epoch always overrides the resume epoch.
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    # The epoch loop below already allows lr_scheduler to be None, so only
    # fast-forward a scheduler that actually exists.
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        logging.info('Scheduled epochs: {}'.format(num_epochs))

    # ---- Data --------------------------------------------------------------
    data_dir = os.path.join(args.data, 'img')
    if not os.path.exists(data_dir):
        logging.error('Training folder does not exist at: {}'.format(data_dir))
        # SystemExit does not depend on the site-provided ``exit`` builtin.
        raise SystemExit(1)
    dataset_train = MultiViewDataSet(train_file,
                                     class_file,
                                     data_dir,
                                     transform=transform_train)
    dataset_eval = MultiViewDataSet(test_file,
                                    class_file,
                                    data_dir,
                                    transform=transform_eval)

    # Without a DistributedSampler every rank would iterate the full training
    # set, duplicating work across processes.
    train_sampler = None
    if args.distributed:
        train_sampler = torch.utils.data.DistributedSampler(dataset_train)
    # NOTE(review): this plain DataLoader applies no mixup collation and does
    # not move batches to the GPU. Confirm that train_epoch handles
    # --mixup > 0 and the prefetcher flag correctly with it.
    loader_train = torch.utils.data.DataLoader(dataset_train,
                                               batch_size=args.batch_size,
                                               shuffle=train_sampler is None,
                                               sampler=train_sampler,
                                               num_workers=1)

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=4 * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
    )

    # ---- Losses ------------------------------------------------------------
    if args.mixup > 0.:
        # Smoothing is handled by the mixup label transform.
        train_loss_fn = SoftTargetCrossEntropy().cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn

    # ---- Output / checkpointing (local rank 0 only) ------------------------
    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = ''
    metrics_history = OrderedDict()
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"), args.model,
            str(data_config['input_size'][-1])
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        saver = CheckpointSaver(checkpoint_dir=output_dir,
                                decreasing=eval_metric == 'loss')

    # ---- Train / validate loop ---------------------------------------------
    try:
        for epoch in range(start_epoch, num_epochs):
            if train_sampler is not None:
                # Give each epoch a different shuffle of the rank's shard.
                train_sampler.set_epoch(epoch)

            train_metrics = train_epoch(epoch,
                                        model,
                                        loader_train,
                                        optimizer,
                                        train_loss_fn,
                                        args,
                                        lr_scheduler=lr_scheduler,
                                        saver=saver,
                                        output_dir=output_dir,
                                        use_amp=use_amp,
                                        model_ema=model_ema)

            eval_metrics = validate(model, loader_eval, validate_loss_fn, args)

            if model_ema is not None and not args.model_ema_force_cpu:
                # When EMA is evaluated, its metrics drive scheduling and saving.
                eval_metrics = validate(model_ema.ema,
                                        loader_eval,
                                        validate_loss_fn,
                                        args,
                                        log_suffix=' (EMA)')

            if lr_scheduler is not None:
                # Step the LR for the next epoch.
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            if saver is not None:
                # Only this process owns output_dir. Other ranks have
                # output_dir == '' and would otherwise write summary.csv into
                # the current working directory.
                update_summary(epoch,
                               train_metrics,
                               eval_metrics,
                               os.path.join(output_dir, 'summary.csv'),
                               write_header=best_metric is None)

                # Save a proper checkpoint tagged with the eval metric.
                save_metric = eval_metrics[eval_metric]
                metrics_history[epoch] = eval_metrics
                make_plots(metrics_history, output_dir)

                best_metric, best_epoch = saver.save_checkpoint(
                    model,
                    optimizer,
                    args,
                    epoch=epoch,
                    model_ema=model_ema,
                    metric=save_metric)

    except KeyboardInterrupt:
        # Ctrl-C ends training early but still reports the best metric below.
        pass
    if best_metric is not None:
        logging.info('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))
Exemplo n.º 3
0
def main(args, config):

    if args.horovod:
        verbose = hvd.rank() == 0
        global_size = hvd.size()
        # global_rank = hvd.rank()
        local_rank = hvd.local_rank()
    else:
        verbose = True
        global_size = 1
        # global_rank = 0
        local_rank = 0

    timestamp = time.strftime("%Y-%m-%d_%H:%M:%S", time.gmtime())
    logdir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'runs',
                          args.architecture, timestamp)

    if verbose:
        writer = tf.summary.FileWriter(logdir=logdir)
        print("Arguments passed:")
        print(args)
        print(f"Saving files to {logdir}")

    else:
        writer = None

    final_shape = parse_tuple(args.final_shape)
    image_channels = final_shape[0]
    final_resolution = final_shape[-1]
    num_phases = int(np.log2(final_resolution) - 1)
    base_dim = num_filters(1, num_phases, size=args.network_size)

    var_list = list()
    global_step = 0

    for phase in range(1, num_phases + 1):

        tf.reset_default_graph()

        # ------------------------------------------------------------------------------------------#
        # DATASET

        size = 2 * 2**phase
        if args.dataset == 'imagenet':
            dataset = imagenet_dataset(
                args.dataset_path,
                args.scratch_path,
                size,
                copy_files=local_rank == 0,
                is_correct_phase=phase >= args.starting_phase,
                gpu=args.gpu,
                num_labels=1 if args.num_labels is None else args.num_labels)
        else:
            raise ValueError(f"Unknown dataset {args.dataset_path}")

        # Get DataLoader
        batch_size = max(1, args.base_batch_size // (2**(phase - 1)))

        if phase >= args.starting_phase:
            assert batch_size * global_size <= args.max_global_batch_size
            if verbose:
                print(
                    f"Using local batch size of {batch_size} and global batch size of {batch_size * global_size}"
                )

        if args.horovod:
            dataset.shard(hvd.size(), hvd.rank())

        dataset = dataset.batch(batch_size, drop_remainder=True)
        dataset = dataset.repeat()
        dataset = dataset.prefetch(AUTOTUNE)
        dataset = dataset.make_one_shot_iterator()
        data = dataset.get_next()
        if len(data) == 1:
            real_image_input = data
            real_label = None
        elif len(data) == 2:
            real_image_input, real_label = data
        else:
            raise NotImplementedError()

        real_image_input = tf.ensure_shape(
            real_image_input, [batch_size, image_channels, size, size])
        real_image_input = real_image_input + tf.random.normal(
            tf.shape(real_image_input)) * .01

        if real_label is not None:
            real_label = tf.one_hot(real_label, depth=args.num_labels)

        # ------------------------------------------------------------------------------------------#
        # OPTIMIZERS

        g_lr = args.g_lr
        d_lr = args.d_lr

        if args.horovod:
            if args.g_scaling == 'sqrt':
                g_lr = g_lr * np.sqrt(hvd.size())
            elif args.g_scaling == 'linear':
                g_lr = g_lr * hvd.size()
            elif args.g_scaling == 'none':
                pass
            else:
                raise ValueError(args.g_scaling)

            if args.d_scaling == 'sqrt':
                d_lr = d_lr * np.sqrt(hvd.size())
            elif args.d_scaling == 'linear':
                d_lr = d_lr * hvd.size()
            elif args.d_scaling == 'none':
                pass
            else:
                raise ValueError(args.d_scaling)

        # d_lr = tf.Variable(d_lr, name='d_lr', dtype=tf.float32)
        # g_lr = tf.Variable(g_lr, name='g_lr', dtype=tf.float32)

        # # optimizer_gen = tf.train.AdamOptimizer(learning_rate=g_lr, beta1=args.beta1, beta2=args.beta2)
        # # optimizer_disc = tf.train.AdamOptimizer(learning_rate=d_lr, beta1=args.beta1, beta2=args.beta2)
        # # optimizer_gen = LAMB(learning_rate=g_lr, beta1=args.beta1, beta2=args.beta2)
        # # optimizer_disc = LAMB(learning_rate=d_lr, beta1=args.beta1, beta2=args.beta2)
        # # optimizer_gen = LARSOptimizer(learning_rate=g_lr, momentum=0, weight_decay=0)
        # # optimizer_disc = LARSOptimizer(learning_rate=d_lr, momentum=0, weight_decay=0)

        # # optimizer_gen = tf.train.RMSPropOptimizer(learning_rate=1e-3)
        # # optimizer_disc = tf.train.RMSPropOptimizer(learning_rate=1e-3)
        # # optimizer_gen = tf.train.GradientDescentOptimizer(learning_rate=1e-3)
        # # optimizer_disc = tf.train.GradientDescentOptimizer(learning_rate=1e-3)
        # # optimizer_gen = RAdamOptimizer(learning_rate=g_lr, beta1=args.beta1, beta2=args.beta2)
        # # optimizer_disc = RAdamOptimizer(learning_rate=d_lr, beta1=args.beta1, beta2=args.beta2)

        # lr_step = tf.Variable(0, name='step', dtype=tf.float32)
        # update_step = lr_step.assign_add(1.0)

        # with tf.control_dependencies([update_step]):
        #     update_g_lr = g_lr.assign(g_lr * args.g_annealing)
        #     update_d_lr = d_lr.assign(d_lr * args.d_annealing)

        # if args.horovod:
        #     if args.use_adasum:
        #         # optimizer_gen = hvd.DistributedOptimizer(optimizer_gen, op=hvd.Adasum)
        #         optimizer_gen = hvd.DistributedOptimizer(optimizer_gen)
        #         optimizer_disc = hvd.DistributedOptimizer(optimizer_disc, op=hvd.Adasum)
        #     else:
        #         optimizer_gen = hvd.DistributedOptimizer(optimizer_gen)
        #         optimizer_disc = hvd.DistributedOptimizer(optimizer_disc)

        # ------------------------------------------------------------------------------------------#
        # NETWORKS

        with tf.variable_scope('alpha'):
            alpha = tf.Variable(1, name='alpha', dtype=tf.float32)
            # Alpha init
            init_alpha = alpha.assign(1)

            # Specify alpha update op for mixing phase.
            num_steps = args.mixing_nimg // (batch_size * global_size)
            alpha_update = 1 / num_steps
            # noinspection PyTypeChecker
            update_alpha = alpha.assign(tf.maximum(alpha - alpha_update, 0))

        base_shape = [image_channels, 4, 4]

        if args.optim_strategy == 'simultaneous':
            gen_loss, disc_loss, gp_loss, gen_sample = forward_simultaneous(
                generator,
                discriminator,
                real_image_input,
                args.latent_dim,
                alpha,
                phase,
                num_phases,
                base_dim,
                base_shape,
                args.activation,
                args.leakiness,
                args.network_size,
                args.loss_fn,
                args.gp_weight,
                conditioning=real_label,
            )

            gen_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                         scope='generator')
            disc_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          scope='discriminator')

            with tf.variable_scope('optimizer_gen'):
                # disc_loss = tf.Print(gen_loss, [gen_loss], 'g_loss')
                optimizer_gen = create_optimizer(
                    gen_loss,
                    gen_vars,
                    1e-8, (args.mixing_nimg + args.stabilizing_nimg) /
                    (batch_size * global_size),
                    8,
                    hvd=hvd,
                    optimizer_type='adam')

            with tf.variable_scope('optimizer_disc'):
                # disc_loss = tf.Print(disc_loss, [disc_loss], 'd_loss')
                optimizer_disc = create_optimizer(
                    disc_loss,
                    disc_vars,
                    1e-8, (args.mixing_nimg + args.stabilizing_nimg) /
                    (batch_size * global_size),
                    8,
                    hvd=hvd,
                    optimizer_type='lamb')

            # if args.horovod:
            #     if args.use_adasum:
            #         # optimizer_gen = hvd.DistributedOptimizer(optimizer_gen, op=hvd.Adasum)
            #         optimizer_gen = hvd.DistributedOptimizer(optimizer_gen, sparse_as_dense=True)
            #         optimizer_disc = hvd.DistributedOptimizer(optimizer_disc, op=hvd.Adasum, sparse_as_dense=True)
            #     else:
            #         optimizer_gen = hvd.DistributedOptimizer(optimizer_gen, sparse_as_dense=True)
            #         optimizer_disc = hvd.DistributedOptimizer(optimizer_disc, sparse_as_dense=True)

            # g_gradients = optimizer_gen.compute_gradients(gen_loss, var_list=gen_vars)
            # d_gradients = optimizer_disc.compute_gradients(disc_loss, var_list=disc_vars)

            # g_norms = tf.stack([tf.norm(grad) for grad, var in g_gradients if grad is not None])
            # max_g_norm = tf.reduce_max(g_norms)
            # d_norms = tf.stack([tf.norm(grad) for grad, var in d_gradients if grad is not None])
            # max_d_norm = tf.reduce_max(d_norms)

            # # g_clipped_grads = [(tf.clip_by_norm(grad, clip_norm=128), var) for grad, var in g_gradients]
            # # train_gen = optimizer_gen.apply_gradients(g_clipped_grads)
            # gs = t
            # train_gen = optimizer_gen.apply_gradients(g_gradients)
            # train_disc = optimizer_disc.apply_gradients(d_gradients)

        # elif args.optim_strategy == 'alternate':

        #     disc_loss, gp_loss = forward_discriminator(
        #         generator,
        #         discriminator,
        #         real_image_input,
        #         args.latent_dim,
        #         alpha,
        #         phase,
        #         num_phases,
        #         base_dim,
        #         base_shape,
        #         args.activation,
        #         args.leakiness,
        #         args.network_size,
        #         args.loss_fn,
        #         args.gp_weight,
        #         conditioning=real_label
        #     )

        #     # disc_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
        #     # d_gradients = optimizer_disc.compute_gradients(disc_loss, var_list=disc_vars)
        #     # d_norms = tf.stack([tf.norm(grad) for grad, var in d_gradients if grad is not None])
        #     # max_d_norm = tf.reduce_max(d_norms)

        #     # train_disc = optimizer_disc.apply_gradients(d_gradients)

        #     with tf.control_dependencies([train_disc]):
        #         gen_sample, gen_loss = forward_generator(
        #             generator,
        #             discriminator,
        #             real_image_input,
        #             args.latent_dim,
        #             alpha,
        #             phase,
        #             num_phases,
        #             base_dim,
        #             base_shape,
        #             args.activation,
        #             args.leakiness,
        #             args.network_size,
        #             args.loss_fn,
        #             is_reuse=True
        #         )

        #         gen_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
        #         g_gradients = optimizer_gen.compute_gradients(gen_loss, var_list=gen_vars)
        #         g_norms = tf.stack([tf.norm(grad) for grad, var in g_gradients if grad is not None])
        #         max_g_norm = tf.reduce_max(g_norms)
        #         train_gen = optimizer_gen.apply_gradients(g_gradients)

        else:
            raise ValueError("Unknown optim strategy ", args.optim_strategy)

        if verbose:
            print(f"Generator parameters: {count_parameters('generator')}")
            print(
                f"Discriminator parameters:: {count_parameters('discriminator')}"
            )

        # train_gen = optimizer_gen.minimize(gen_loss, var_list=gen_vars)
        # train_disc = optimizer_disc.minimize(disc_loss, var_list=disc_vars)

        ema = tf.train.ExponentialMovingAverage(decay=args.ema_beta)
        ema_op = ema.apply(gen_vars)
        # Transfer EMA values to original variables
        ema_update_weights = tf.group(
            [tf.assign(var, ema.average(var)) for var in gen_vars])

        with tf.name_scope('summaries'):
            # Summaries
            tf.summary.scalar('d_loss', disc_loss)
            tf.summary.scalar('g_loss', gen_loss)
            tf.summary.scalar('gp', tf.reduce_mean(gp_loss))

            # for g in g_gradients:
            #     tf.summary.histogram(f'grad_{g[1].name}', g[0])

            # for g in d_gradients:
            #     tf.summary.histogram(f'grad_{g[1].name}', g[0])

            # tf.summary.scalar('convergence', tf.reduce_mean(disc_real) - tf.reduce_mean(tf.reduce_mean(disc_fake_d)))

            # tf.summary.scalar('max_g_grad_norm', max_g_norm)
            # tf.summary.scalar('max_d_grad_norm', max_d_norm)

            real_image_grid = tf.transpose(real_image_input,
                                           (0, 2, 3, 1))  # D H W C  -> B H W C
            shape = real_image_grid.get_shape().as_list()
            grid_cols = int(2**np.floor(np.log(np.sqrt(shape[0])) / np.log(2)))
            grid_rows = shape[0] // grid_cols
            grid_shape = [grid_rows, grid_cols]
            real_image_grid = image_grid(real_image_grid,
                                         grid_shape,
                                         image_shape=shape[1:3],
                                         num_channels=shape[-1])

            fake_image_grid = tf.transpose(gen_sample, (0, 2, 3, 1))
            fake_image_grid = image_grid(fake_image_grid,
                                         grid_shape,
                                         image_shape=shape[1:3],
                                         num_channels=shape[-1])

            fake_image_grid = tf.clip_by_value(fake_image_grid, -1, 1)

            tf.summary.image('real_image', real_image_grid)
            tf.summary.image('fake_image', fake_image_grid)

            tf.summary.scalar('fake_image_min', tf.math.reduce_min(gen_sample))
            tf.summary.scalar('fake_image_max', tf.math.reduce_max(gen_sample))

            tf.summary.scalar('real_image_min',
                              tf.math.reduce_min(real_image_input[0]))
            tf.summary.scalar('real_image_max',
                              tf.math.reduce_max(real_image_input[0]))
            tf.summary.scalar('alpha', alpha)

            tf.summary.scalar('g_lr', g_lr)
            tf.summary.scalar('d_lr', d_lr)

            merged_summaries = tf.summary.merge_all()

        # Other ops
        init_op = tf.global_variables_initializer()
        assign_starting_alpha = alpha.assign(args.starting_alpha)
        assign_zero = alpha.assign(0)
        broadcast = hvd.broadcast_global_variables(0)

        with tf.Session(config=config) as sess:
            sess.run(init_op)

            trainable_variable_names = [
                v.name for v in tf.trainable_variables()
            ]

            if var_list is not None and phase > args.starting_phase:
                print("Restoring variables from:",
                      os.path.join(logdir, f'model_{phase - 1}'))
                var_names = [v.name for v in var_list]
                load_vars = [
                    sess.graph.get_tensor_by_name(n) for n in var_names
                    if n in trainable_variable_names
                ]
                saver = tf.train.Saver(load_vars)
                saver.restore(sess, os.path.join(logdir, f'model_{phase - 1}'))
            elif var_list is not None and args.continue_path and phase == args.starting_phase:
                print("Restoring variables from:", args.continue_path)
                var_names = [v.name for v in var_list]
                load_vars = [
                    sess.graph.get_tensor_by_name(n) for n in var_names
                    if n in trainable_variable_names
                ]
                saver = tf.train.Saver(load_vars)
                saver.restore(sess, os.path.join(args.continue_path))
            else:
                if verbose:
                    print("Not restoring variables.")
                    print("Variable List Length:", len(var_list))

            var_list = gen_vars + disc_vars

            if phase < args.starting_phase:
                continue

            if phase == args.starting_phase:
                sess.run(assign_starting_alpha)
            else:
                sess.run(init_alpha)

            if verbose:
                print(f"Begin mixing epochs in phase {phase}")
            if args.horovod:
                sess.run(broadcast)

            local_step = 0
            # take_first_snapshot = True

            while True:
                start = time.time()
                if local_step % 128 == 0 and local_step > 1:
                    if args.horovod:
                        sess.run(broadcast)
                    saver = tf.train.Saver(var_list)
                    if verbose:
                        saver.save(
                            sess,
                            os.path.join(logdir,
                                         f'model_{phase}_ckpt_{global_step}'))

                # _, _, summary, d_loss, g_loss = sess.run(
                #      [train_gen, train_disc, merged_summaries,
                #       disc_loss, gen_loss])

                _, _, summary, d_loss, g_loss = sess.run([
                    optimizer_gen, optimizer_disc, merged_summaries, disc_loss,
                    gen_loss
                ])

                global_step += batch_size * global_size
                local_step += 1

                end = time.time()
                img_s = global_size * batch_size / (end - start)
                if verbose:

                    writer.add_summary(summary, global_step)
                    writer.add_summary(
                        tf.Summary(value=[
                            tf.Summary.Value(tag='img_s', simple_value=img_s)
                        ]), global_step)
                    memory_percentage = psutil.Process(
                        os.getpid()).memory_percent()
                    writer.add_summary(
                        tf.Summary(value=[
                            tf.Summary.Value(tag='memory_percentage',
                                             simple_value=memory_percentage)
                        ]), global_step)

                    print(f"Step {global_step:09} \t"
                          f"img/s {img_s:.2f} \t "
                          f"d_loss {d_loss:.4f} \t "
                          f"g_loss {g_loss:.4f} \t "
                          f"memory {memory_percentage:.4f} % \t"
                          f"alpha {alpha.eval():.2f}")

                    # if take_first_snapshot:
                    #     import tracemalloc
                    #     tracemalloc.start()
                    #     snapshot_first = tracemalloc.take_snapshot()
                    #     take_first_snapshot = False

                    # snapshot = tracemalloc.take_snapshot()
                    # top_stats = snapshot.compare_to(snapshot_first, 'lineno')
                    # print("[ Top 10 differences ]")
                    # for stat in top_stats[:10]:
                    #     print(stat)
                    # snapshot_prev = snapshot

                if global_step >= ((phase - args.starting_phase) *
                                   (args.mixing_nimg + args.stabilizing_nimg) +
                                   args.mixing_nimg):
                    break

                sess.run(update_alpha)
                sess.run(ema_op)
                # sess.run(update_d_lr)
                # sess.run(update_g_lr)

                assert alpha.eval() >= 0

                if verbose:
                    writer.flush()

            if verbose:
                print(f"Begin stabilizing epochs in phase {phase}")

            sess.run(assign_zero)

            while True:
                start = time.time()
                assert alpha.eval() == 0
                if local_step % 128 == 0 and local_step > 0:

                    if args.horovod:
                        sess.run(broadcast)
                    saver = tf.train.Saver(var_list)
                    if verbose:
                        saver.save(
                            sess,
                            os.path.join(logdir,
                                         f'model_{phase}_ckpt_{global_step}'))

                # _, _, summary, d_loss, g_loss = sess.run(
                #      [train_gen, train_disc, merged_summaries,
                #       disc_loss, gen_loss])

                _, _, summary, d_loss, g_loss = sess.run([
                    optimizer_gen, optimizer_disc, merged_summaries, disc_loss,
                    gen_loss
                ])

                global_step += batch_size * global_size
                local_step += 1

                end = time.time()
                img_s = global_size * batch_size / (end - start)
                if verbose:
                    writer.add_summary(
                        tf.Summary(value=[
                            tf.Summary.Value(tag='img_s', simple_value=img_s)
                        ]), global_step)
                    writer.add_summary(summary, global_step)
                    memory_percentage = psutil.Process(
                        os.getpid()).memory_percent()
                    writer.add_summary(
                        tf.Summary(value=[
                            tf.Summary.Value(tag='memory_percentage',
                                             simple_value=memory_percentage)
                        ]), global_step)

                    print(f"Step {global_step:09} \t"
                          f"img/s {img_s:.2f} \t "
                          f"d_loss {d_loss:.4f} \t "
                          f"g_loss {g_loss:.4f} \t "
                          f"memory {memory_percentage:.4f} % \t"
                          f"alpha {alpha.eval():.2f}")

                sess.run(ema_op)

                if verbose:
                    writer.flush()

                if global_step >= (phase - args.starting_phase + 1) * (
                        args.stabilizing_nimg + args.mixing_nimg):
                    # if verbose:
                    #     run_metadata = tf.RunMetadata()
                    #     opts = tf.profiler.ProfileOptionBuilder.float_operation()
                    #     g = tf.get_default_graph()
                    #     flops = tf.profiler.profile(g, run_meta=run_metadata, cmd='op', options=opts)
                    #     writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag='graph_flops',
                    #                                                           simple_value=flops.total_float_ops)]),
                    #                        global_step)
                    #
                    #     # Print memory info.
                    #     try:
                    #         print(nvgpu.gpu_info())
                    #     except subprocess.CalledProcessError:
                    #         pid = os.getpid()
                    #         py = psutil.Process(pid)
                    #         print(f"CPU Percent: {py.cpu_percent()}")
                    #         print(f"Memory info: {py.memory_info()}")

                    break

            # # Calculate metrics.
            # calc_swds: bool = size >= 16
            # calc_ssims: bool = min(npy_data.shape[1:]) >= 16
            #
            # if args.calc_metrics:
            #     fids_local = []
            #     swds_local = []
            #     psnrs_local = []
            #     mses_local = []
            #     nrmses_local = []
            #     ssims_local = []
            #
            #     counter = 0
            #     while True:
            #         if args.horovod:
            #             start_loc = counter + hvd.rank() * batch_size
            #         else:
            #             start_loc = 0
            #         real_batch = np.stack([npy_data[i] for i in range(start_loc, start_loc + batch_size)])
            #         real_batch = real_batch.astype(np.int16) - 1024
            #         fake_batch = sess.run(gen_sample).astype(np.float32)
            #
            #         # Turn fake batch into HUs and clip to training range.
            #         fake_batch = (np.clip(fake_batch, -1, 2) * 1024).astype(np.int16)
            #
            #         if verbose:
            #             print('real min, max', real_batch.min(), real_batch.max())
            #             print('fake min, max', fake_batch.min(), fake_batch.max())
            #
            #         fids_local.append(calculate_fid_given_batch_volumes(real_batch, fake_batch, sess))
            #
            #         if calc_swds:
            #             swds = get_swd_for_volumes(real_batch, fake_batch)
            #             swds_local.append(swds)
            #
            #         psnr = get_psnr(real_batch, fake_batch)
            #         if calc_ssims:
            #             ssim = get_ssim(real_batch, fake_batch)
            #             ssims_local.append(ssim)
            #         mse = get_mean_squared_error(real_batch, fake_batch)
            #         nrmse = get_normalized_root_mse(real_batch, fake_batch)
            #
            #         psnrs_local.append(psnr)
            #         mses_local.append(mse)
            #         nrmses_local.append(nrmse)
            #
            #         if args.horovod:
            #             counter = counter + global_size * batch_size
            #         else:
            #             counter += batch_size
            #
            #         if counter >= args.num_metric_samples:
            #             break
            #
            #     fid_local = np.mean(fids_local)
            #     psnr_local = np.mean(psnrs_local)
            #     ssim_local = np.mean(ssims_local)
            #     mse_local = np.mean(mses_local)
            #     nrmse_local = np.mean(nrmses_local)
            #
            #     if args.horovod:
            #         fid = MPI.COMM_WORLD.allreduce(fid_local, op=MPI.SUM) / hvd.size()
            #         psnr = MPI.COMM_WORLD.allreduce(psnr_local, op=MPI.SUM) / hvd.size()
            #         mse = MPI.COMM_WORLD.allreduce(mse_local, op=MPI.SUM) / hvd.size()
            #         nrmse = MPI.COMM_WORLD.allreduce(nrmse_local, op=MPI.SUM) / hvd.size()
            #         if calc_ssims:
            #             ssim = MPI.COMM_WORLD.allreduce(ssim_local, op=MPI.SUM) / hvd.size()
            #     else:
            #         fid = fid_local
            #         psnr = psnr_local
            #         ssim = ssim_local
            #         mse = mse_local
            #         nrmse = nrmse_local
            #
            #     if calc_swds:
            #         swds_local = np.array(swds_local)
            #         # Average over batches
            #         swds_local = swds_local.mean(axis=0)
            #         if args.horovod:
            #             swds = MPI.COMM_WORLD.allreduce(swds_local, op=MPI.SUM) / hvd.size()
            #         else:
            #             swds = swds_local
            #
            #     if verbose:
            #         print(f"FID: {fid:.4f}")
            #         writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag='fid',
            #                                                               simple_value=fid)]),
            #                            global_step)
            #
            #         print(f"PSNR: {psnr:.4f}")
            #         writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag='psnr',
            #                                                               simple_value=psnr)]),
            #                            global_step)
            #
            #         print(f"MSE: {mse:.4f}")
            #         writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag='mse',
            #                                                               simple_value=mse)]),
            #                            global_step)
            #
            #         print(f"Normalized Root MSE: {nrmse:.4f}")
            #         writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag='nrmse',
            #                                                               simple_value=nrmse)]),
            #                            global_step)
            #
            #         if calc_swds:
            #             print(f"SWDS: {swds}")
            #             for i in range(len(swds))[:-1]:
            #                 lod = 16 * 2 ** i
            #                 writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag=f'swd_{lod}',
            #                                                                       simple_value=swds[
            #                                                                           i])]),
            #                                    global_step)
            #             writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag=f'swd_mean',
            #                                                                   simple_value=swds[
            #                                                                       -1])]), global_step)
            #         if calc_ssims:
            #             print(f"SSIM: {ssim}")
            #             writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag=f'ssim',
            #                                                                   simple_value=ssim)]), global_step)

            if verbose:
                print("\n\n\n End of phase.")

                # Save Session.
                sess.run(ema_update_weights)
                saver = tf.train.Saver(var_list)
                saver.save(sess, os.path.join(logdir, f'model_{phase}'))

            if args.ending_phase:
                if phase == args.ending_phase:
                    print("Reached final phase, breaking.")
                    break
Exemplo n.º 4
0
def main():
    """Parse command-line arguments and run the full train/validate loop.

    Sets up single-process or distributed (one GPU per process, NCCL) CUDA
    training. It builds the model, optimizer, LR scheduler, data loaders and
    loss functions from the parsed ``args``. It then trains for ``num_epochs``
    epochs, validating, logging a summary row and checkpointing after each
    one. A ``KeyboardInterrupt`` stops training early; the best metric
    recorded so far is still printed.

    Raises:
        SystemExit: with status 1 if ``<args.data>/train`` or
            ``<args.data>/validation`` is not a directory.
    """
    args = parser.parse_args()

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        # WORLD_SIZE is provided by the distributed launcher (env:// init);
        # more than one process means distributed training.
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
        if args.distributed and args.num_gpu > 1:
            print(
                'Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.'
            )
            args.num_gpu = 1

    args.device = 'cuda:0'
    args.world_size = 1
    r = -1
    if args.distributed:
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        r = torch.distributed.get_rank()

    if args.distributed:
        print(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (r, args.world_size))
    else:
        print('Training with a single process on %d GPUs.' % args.num_gpu)

    # FIXME seed handling for multi-process distributed?
    torch.manual_seed(args.seed)

    # Only local rank 0 gets an output directory. Other ranks keep '' and
    # must not write summaries or checkpoints (see the guards below).
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"), args.model,
            str(args.img_size)
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)

    model = create_model(args.model,
                         pretrained=args.pretrained,
                         num_classes=args.num_classes,
                         drop_rate=args.drop,
                         global_pool=args.gp,
                         bn_tf=args.bn_tf,
                         bn_momentum=args.bn_momentum,
                         bn_eps=args.bn_eps,
                         checkpoint_path=args.initial_checkpoint)

    print('Model %s created, param count: %d' %
          (args.model, sum(m.numel() for m in model.parameters())))

    data_config = resolve_data_config(model,
                                      args,
                                      verbose=args.local_rank == 0)

    # optionally resume from a checkpoint
    start_epoch = 0
    optimizer_state = None
    if args.resume:
        optimizer_state, start_epoch = resume_checkpoint(
            model, args.resume, args.start_epoch)

    if args.num_gpu > 1:
        if args.amp:
            print(
                'Warning: AMP does not work well with nn.DataParallel, disabling. '
                'Use distributed mode for multi-GPU AMP.')
            args.amp = False
        model = nn.DataParallel(model,
                                device_ids=list(range(args.num_gpu))).cuda()
    else:
        model.cuda()

    optimizer = create_optimizer(args, model)
    if optimizer_state is not None:
        optimizer.load_state_dict(optimizer_state)

    if has_apex and args.amp:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        use_amp = True
        print('AMP enabled')
    else:
        use_amp = False
        print('AMP disabled')

    if args.distributed:
        model = DDP(model, delay_allreduce=True)

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    # ``step(e)`` sets the LR for epoch ``e``. When resuming, jump straight to
    # the resumed epoch. Guard against a missing scheduler, as the loop does.
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)
    if args.local_rank == 0:
        print('Scheduled epochs: ', num_epochs)

    train_dir = os.path.join(args.data, 'train')
    if not os.path.isdir(train_dir):
        print('Error: training folder does not exist at: %s' % train_dir)
        raise SystemExit(1)
    dataset_train = Dataset(train_dir)

    collate_fn = None
    if args.prefetcher and args.mixup > 0:
        collate_fn = FastCollateMixup(args.mixup, args.smoothing,
                                      args.num_classes)

    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        rand_erase_prob=args.reprob,
        rand_erase_mode=args.remode,
        interpolation=
        'random',  # FIXME cleanly resolve this? data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
    )

    eval_dir = os.path.join(args.data, 'validation')
    if not os.path.isdir(eval_dir):
        print('Error: validation folder does not exist at: %s' % eval_dir)
        raise SystemExit(1)
    dataset_eval = Dataset(eval_dir)

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=4 * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
    )

    if args.mixup > 0.:
        # smoothing is handled with mixup label transform
        train_loss_fn = SoftTargetCrossEntropy().cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn

    eval_metric = args.eval_metric
    saver = None
    if output_dir:
        # Lower is better only for the loss metric.
        saver = CheckpointSaver(checkpoint_dir=output_dir,
                                decreasing=eval_metric == 'loss')
    best_metric = None
    best_epoch = None
    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
                # Reshuffle the distributed shards differently each epoch.
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(epoch,
                                        model,
                                        loader_train,
                                        optimizer,
                                        train_loss_fn,
                                        args,
                                        lr_scheduler=lr_scheduler,
                                        saver=saver,
                                        output_dir=output_dir,
                                        use_amp=use_amp)

            eval_metrics = validate(model, loader_eval, validate_loss_fn, args)

            if lr_scheduler is not None:
                # Epoch ``epoch`` is finished: set the LR for the next one.
                # Stepping with ``epoch`` would re-apply the LR just used.
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            # Only the rank that owns an output directory writes the summary.
            # Otherwise every other rank would write (and race on)
            # ``summary.csv`` in the current working directory.
            if output_dir:
                update_summary(epoch,
                               train_metrics,
                               eval_metrics,
                               os.path.join(output_dir, 'summary.csv'),
                               write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                best_metric, best_epoch = saver.save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': args.model,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'args': args,
                    },
                    epoch=epoch + 1,
                    metric=eval_metrics[eval_metric])

    except KeyboardInterrupt:
        # Deliberate: allow Ctrl-C to end training and still report the best.
        pass
    if best_metric is not None:
        print('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))
def main():
    """Run config-driven training: build everything from ``cfg`` and train.

    Reads ``cfg, args`` from ``_parse_args()`` and creates a timestamped
    output directory. It dumps the config there as ``config.yaml``, sets up a
    file logger, and builds the model on the GPUs listed in ``args.gpu``
    (comma-separated; an empty string means CPU). The optimizer, optional EMA
    model and LR scheduler are set up next; a checkpoint can be loaded or
    resumed. It then trains for ``num_epochs`` epochs, validating (with the
    EMA weights if enabled), logging a summary row and checkpointing after
    each one. A ``KeyboardInterrupt`` stops training early and still logs the
    best metric.
    """
    cfg, args = _parse_args()
    torch.manual_seed(args.seed)

    output_base = cfg.OUTPUT_DIR or './output'
    exp_name = '-'.join([
        datetime.now().strftime("%Y%m%d-%H%M%S"), cfg.MODEL.ARCHITECTURE,
        str(cfg.INPUT.IMG_SIZE)
    ])
    output_dir = get_outdir(output_base, exp_name)
    with open(os.path.join(output_dir, 'config.yaml'), 'w',
              encoding='utf-8') as file_writer:
        # cfg.dump(stream=file_writer, default_flow_style=False, indent=2, allow_unicode=True)
        file_writer.write(pyaml.dump(cfg))
    logger = setup_logger(file_name=os.path.join(output_dir, 'train.log'),
                          control_log=False,
                          log_level='INFO')

    # create model
    model = create_model(cfg.MODEL.ARCHITECTURE,
                         num_classes=cfg.MODEL.NUM_CLASSES,
                         pretrained=True,
                         in_chans=cfg.INPUT.IN_CHANNELS,
                         drop_rate=cfg.MODEL.DROP_RATE,
                         drop_connect_rate=cfg.MODEL.DROP_CONNECT,
                         global_pool=cfg.MODEL.GLOBAL_POOL)

    # Restrict this process to the requested GPUs. Once CUDA_VISIBLE_DEVICES
    # takes effect, the visible devices are renumbered 0..k-1. DataParallel
    # must therefore use those local indices, not the physical ids in
    # args.gpu. Empty entries are dropped, so ``--gpu ''`` selects the CPU
    # instead of failing on int('').
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    gpu_list = [int(g) for g in args.gpu.split(',') if g.strip()]
    if gpu_list:
        device = 'cuda'
        if len(gpu_list) > 1:
            model = nn.DataParallel(model,
                                    device_ids=list(range(len(gpu_list))))
            # convert_model is a project helper applied to the
            # DataParallel-wrapped model; presumably it swaps in synchronized
            # BatchNorm — TODO confirm.
            model = convert_model(model).cuda()
        else:
            model.cuda()
        torch.backends.cudnn.benchmark = True
    else:
        device = 'cpu'
    logger.info('device: {}, gpu_list: {}'.format(device, gpu_list))

    optimizer = create_optimizer(cfg, model)

    # optionally initialize from a checkpoint
    if args.initial_checkpoint and os.path.isfile(args.initial_checkpoint):
        load_checkpoint(model, args.initial_checkpoint)

    # optionally resume from a checkpoint
    resume_state = None
    resume_epoch = None
    if args.resume and os.path.isfile(args.resume):
        resume_state, resume_epoch = resume_checkpoint(model, args.resume)
    if resume_state and not args.no_resume_opt:
        if 'optimizer' in resume_state:
            optimizer.load_state_dict(resume_state['optimizer'])
            logger.info('Restoring optimizer state from [{}]'.format(
                args.resume))

    # An explicit --start-epoch wins over the epoch stored in the checkpoint.
    start_epoch = 0
    if args.start_epoch is not None:
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch

    model_ema = None
    if cfg.SOLVER.EMA:
        # Important to create EMA model after cuda()
        model_ema = ModelEma(model,
                             decay=cfg.SOLVER.EMA_DECAY,
                             device=device,
                             resume=args.resume)

    lr_scheduler, num_epochs = create_scheduler(cfg, optimizer)
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    # summary
    print('=' * 60)
    print(cfg)
    print('=' * 60)
    print(model)
    print('=' * 60)
    # Use the configured channel count: the model was built with
    # in_chans=cfg.INPUT.IN_CHANNELS, so a hard-coded 3 breaks non-RGB input.
    # NOTE(review): if ``summary`` is torchsummary it defaults to a CUDA
    # device; the CPU path may need device='cpu' — confirm.
    summary(model, (cfg.INPUT.IN_CHANNELS, cfg.INPUT.IMG_SIZE,
                    cfg.INPUT.IMG_SIZE))

    # dataset
    dataset_train = Dataset(cfg.DATASETS.TRAIN)
    dataset_valid = Dataset(cfg.DATASETS.TEST)
    train_loader = create_loader(dataset_train, cfg, is_training=True)
    valid_loader = create_loader(dataset_valid, cfg, is_training=False)

    # loss function
    if cfg.SOLVER.LABEL_SMOOTHING > 0:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=cfg.SOLVER.LABEL_SMOOTHING).to(device)
        validate_loss_fn = nn.CrossEntropyLoss().to(device)
    else:
        train_loss_fn = nn.CrossEntropyLoss().to(device)
        validate_loss_fn = train_loss_fn

    eval_metric = cfg.SOLVER.EVAL_METRIC
    best_metric = None
    best_epoch = None
    # Lower is better only for the loss metric.
    saver = CheckpointSaver(checkpoint_dir=output_dir,
                            recovery_dir=output_dir,
                            decreasing=eval_metric == 'loss')
    try:
        for epoch in range(start_epoch, num_epochs):
            train_metrics = train_epoch(epoch,
                                        model,
                                        train_loader,
                                        optimizer,
                                        train_loss_fn,
                                        cfg,
                                        logger,
                                        lr_scheduler=lr_scheduler,
                                        saver=saver,
                                        device=device,
                                        model_ema=model_ema)

            eval_metrics = validate(epoch, model, valid_loader,
                                    validate_loss_fn, cfg, logger)

            if model_ema is not None:
                # When EMA is enabled, its metrics drive scheduling and
                # checkpoint selection instead of the raw model's.
                eval_metrics = validate(epoch, model_ema.ema, valid_loader,
                                        validate_loss_fn, cfg, logger)

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            update_summary(epoch,
                           train_metrics,
                           eval_metrics,
                           os.path.join(output_dir, 'summary.csv'),
                           write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                best_metric, best_epoch = saver.save_checkpoint(
                    model,
                    optimizer,
                    cfg,
                    epoch=epoch,
                    model_ema=model_ema,
                    metric=eval_metrics[eval_metric])

    except KeyboardInterrupt:
        # Deliberate: allow Ctrl-C to end training and still log the best.
        pass
    if best_metric is not None:
        logger.info('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))