Example 1
def main_worker(gpu, ngpus_per_node, args):
    start_time = datetime.now()
    if len(args.gpu) == 1:
        args.gpu = 0
    else:
        args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:'+args.port,
                                world_size=args.world_size, rank=args.rank)

    if '2' in args.dataset:
        args.output_k = 2

    # # of GT-classes
    args.num_cls = args.output_k

    # Classes to use
    if args.dataset == 'animal_faces':
        args.att_to_use = [11, 43, 56, 74, 89, 128, 130, 138, 140, 141]
    elif args.dataset == 'afhq_cat':
        args.att_to_use = [0, ]
    elif args.dataset == 'afhq_dog':
        args.att_to_use = [1, ]
    elif args.dataset == 'afhq_wild':
        args.att_to_use = [2, ]
    elif '2' in args.dataset:
        args.att_to_use = [0, 1]
        assert args.num_cls == len(args.att_to_use)
    elif args.dataset in ['ffhq', 'lsun_car']:
        args.att_to_use = [0, ]

    # IIC statistics
    args.epoch_acc = []
    args.epoch_avg_subhead_acc = []
    args.epoch_stats = []

    # Logging
    logger = SummaryWriter(args.event_dir)

    # build model - return dict
    networks, opts = build_model(args)

    # load model if args.load_model is specified
    load_model(args, networks, opts)
    cudnn.benchmark = True

    # get dataset and data loader
    train_dataset, val_dataset = get_dataset(args.dataset, args)
    train_loader, val_loader, train_sampler = get_loader(args, {'train': train_dataset, 'val': val_dataset})

    # map the functions to execute - unsupervised / supervised / semi-supervised
    trainFunc, validationFunc = map_exec_func(args)

    queue_loader = train_loader['UNSUP'] if 0.0 < args.p_semi < 1.0 else train_loader

    queue = initialize_queue(networks['C_EMA'], args.gpu, queue_loader, feat_size=args.sty_dim)

    # print all the arguments
    print_args(args)

    # All testing is done during training - no separate call is needed
    if args.validation:
        validationFunc(val_loader, networks, 999, args, {'logger': logger, 'queue': queue})
        fid_ema = calcFIDBatch(args, {'VAL': val_loader, 'TRAIN': train_loader}, networks, 'EMA', train_dataset)
        fid_ema_mean = sum(fid_ema) / (len(fid_ema))
        print("Mean FID : [{}] ".format(fid_ema_mean))
        return

    # Record the run arguments (main process only)
    if not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0):
        record_txt = open(os.path.join(args.log_dir, "record.txt"), "a+")
        for arg in vars(args):
            record_txt.write('{:35}{:20}\n'.format(arg, str(getattr(args, arg))))
        record_txt.close()

    # Run
    validationFunc(val_loader, networks, 0, args, {'logger': logger, 'queue': queue})

    fid_best_ema = 999.0

    for epoch in range(args.start_epoch, args.epochs):
        tme = (datetime.now() - start_time).total_seconds() / 3600  # elapsed hours

        # Timeout training
        if tme >= args.timeout:
            save_model(args, epoch, networks, opts)
            print("Training reached timeout")
            break

        print("START EPOCH[{}]".format(epoch+1))
        if (epoch + 1) % (args.epochs // 10) == 0:
            save_model(args, epoch, networks, opts)

        if args.distributed:
            if 0.0 < args.p_semi < 1.0:
                assert 'SEMI' in args.train_mode
                train_sampler['SUP'].set_epoch(epoch)
                train_sampler['UNSUP'].set_epoch(epoch)
            else:
                train_sampler.set_epoch(epoch)

        if epoch == args.ema_start and 'GAN' in args.train_mode:
            if args.distributed:
                networks['G_EMA'].module.load_state_dict(networks['G'].module.state_dict())
            else:
                networks['G_EMA'].load_state_dict(networks['G'].state_dict())

        trainFunc(train_loader, networks, opts, epoch, args, {'logger': logger, 'queue': queue})

        validationFunc(val_loader, networks, epoch, args, {'logger': logger, 'queue': queue})

        # Calc fid
        if epoch >= args.fid_start and args.dataset not in ['ffhq', 'lsun_car', 'afhq_cat', 'afhq_dog', 'afhq_wild']:
            fid_ema = calcFIDBatch(args, {'VAL': val_loader, 'TRAIN': train_loader}, networks, 'EMA', train_dataset)
            fid_ema_mean = sum(fid_ema) / (len(fid_ema))

            if fid_best_ema > fid_ema_mean:
                fid_best_ema = fid_ema_mean
                save_model(args, 4567, networks, opts)

            print("Mean FID : [{}] AT EPOCH[{}] G_EMA / BEST EVER[{}]".format(fid_ema_mean, epoch + 1, fid_best_ema))

        # Write logs
        if not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank % args.ngpus_per_node == 0):
            if (epoch + 1) % 10 == 0:
                save_model(args, epoch, networks, opts)
            if args.train_mode not in ['CLS_UN', 'CLS_SEMI']:
                if epoch >= args.fid_start and args.dataset not in ['ffhq', 'lsun_car']:
                    for idx_fid in range(len(args.att_to_use)):
                        add_logs(args, logger, 'STATEMA/G_EMA{}/FID'.format(idx_fid), fid_ema[idx_fid], epoch + 1)
                    add_logs(args, logger, 'STATEMA/G_EMA/mFID', fid_ema_mean, epoch + 1)
            if len(args.epoch_acc) > 0:
                add_logs(args, logger, 'STATC/Acc', float(args.epoch_acc[-1]), epoch + 1)
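
For reference: main_worker(gpu, ngpus_per_node, args) follows the standard PyTorch distributed entry-point signature, so it is normally spawned once per GPU. Below is a minimal launcher sketch assuming the usual torch.multiprocessing.spawn pattern; the launch helper itself is hypothetical, and the flag names simply mirror the args fields used above.

import torch
import torch.multiprocessing as mp

def launch(args):
    # One worker process per visible GPU when multiprocessing_distributed is set.
    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # world_size becomes the total number of processes across all nodes.
        args.world_size = ngpus_per_node * args.world_size
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        # Single-process fallback: run the worker directly in this process.
        main_worker(args.gpu, ngpus_per_node, args)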
Example 2
def main_worker(gpu, ngpus_per_node, args):
    if len(args.gpu) == 1:
        args.gpu = 0
    else:
        args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend='nccl',
                                init_method='tcp://127.0.0.1:' + args.port,
                                world_size=args.world_size,
                                rank=args.rank)

    ################
    # Define model #
    ################
    # 4/3 : scale factor in the paper
    scale_factor = 4 / 3
    tmp_scale = args.img_size_max / args.img_size_min
    args.num_scale = int(np.round(np.log(tmp_scale) / np.log(scale_factor)))
    args.size_list = [
        int(args.img_size_min * scale_factor**i)
        for i in range(args.num_scale + 1)
    ]

    discriminator = Discriminator()
    generator = Generator(args.img_size_min, args.num_scale, scale_factor)

    networks = [discriminator, generator]

    if args.distributed:
        if args.gpu is not None:
            print('Distributed to', args.gpu)
            torch.cuda.set_device(args.gpu)
            networks = [x.cuda(args.gpu) for x in networks]
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(args.workers / ngpus_per_node)
            networks = [
                torch.nn.parallel.DistributedDataParallel(
                    x, device_ids=[args.gpu], output_device=args.gpu)
                for x in networks
            ]
        else:
            networks = [x.cuda() for x in networks]
            networks = [
                torch.nn.parallel.DistributedDataParallel(x) for x in networks
            ]

    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        networks = [x.cuda(args.gpu) for x in networks]
    else:
        networks = [torch.nn.DataParallel(x).cuda() for x in networks]

    discriminator, generator, = networks

    ######################
    # Loss and Optimizer #
    ######################
    if args.distributed:
        d_opt = torch.optim.Adam(
            discriminator.module.sub_discriminators[0].parameters(), 5e-4,
            (0.5, 0.999))
        g_opt = torch.optim.Adam(
            generator.module.sub_generators[0].parameters(), 5e-4,
            (0.5, 0.999))
    else:
        d_opt = torch.optim.Adam(
            discriminator.sub_discriminators[0].parameters(), 5e-4,
            (0.5, 0.999))
        g_opt = torch.optim.Adam(generator.sub_generators[0].parameters(),
                                 5e-4, (0.5, 0.999))

    ##############
    # Load model #
    ##############
    args.stage = 0
    if args.load_model is not None:
        check_load = open(os.path.join(args.log_dir, "checkpoint.txt"), 'r')
        to_restore = check_load.readlines()[-1].strip()
        load_file = os.path.join(args.log_dir, to_restore)
        if os.path.isfile(load_file):
            print("=> loading checkpoint '{}'".format(load_file))
            checkpoint = torch.load(load_file, map_location='cpu')
            for _ in range(int(checkpoint['stage'])):
                generator.progress()
                discriminator.progress()
            networks = [discriminator, generator]

            if args.distributed:
                if args.gpu is not None:
                    print('Distributed to', args.gpu)
                    torch.cuda.set_device(args.gpu)
                    networks = [x.cuda(args.gpu) for x in networks]
                    args.batch_size = int(args.batch_size / ngpus_per_node)
                    args.workers = int(args.workers / ngpus_per_node)
                    networks = [
                        torch.nn.parallel.DistributedDataParallel(
                            x, device_ids=[args.gpu], output_device=args.gpu)
                        for x in networks
                    ]
                else:
                    networks = [x.cuda() for x in networks]
                    networks = [
                        torch.nn.parallel.DistributedDataParallel(x)
                        for x in networks
                    ]

            elif args.gpu is not None:
                torch.cuda.set_device(args.gpu)
                networks = [x.cuda(args.gpu) for x in networks]
            else:
                networks = [torch.nn.DataParallel(x).cuda() for x in networks]

            discriminator, generator, = networks

            args.stage = checkpoint['stage']
            args.img_to_use = checkpoint['img_to_use']
            discriminator.load_state_dict(checkpoint['D_state_dict'])
            generator.load_state_dict(checkpoint['G_state_dict'])
            d_opt.load_state_dict(checkpoint['d_optimizer'])
            g_opt.load_state_dict(checkpoint['g_optimizer'])
            print("=> loaded checkpoint '{}' (stage {})".format(
                load_file, checkpoint['stage']))
        else:
            print("=> no checkpoint found at '{}'".format(args.log_dir))

    cudnn.benchmark = True

    ###########
    # Dataset #
    ###########
    train_dataset, _ = get_dataset(args.dataset, args)

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    ######################
    # Validate and Train #
    ######################
    z_fix_list = [
        F.pad(torch.randn(args.batch_size, 3, args.size_list[0],
                          args.size_list[0]), [5, 5, 5, 5],
              value=0)
    ]
    zero_list = [
        F.pad(torch.zeros(args.batch_size, 3, args.size_list[zeros_idx],
                          args.size_list[zeros_idx]), [5, 5, 5, 5],
              value=0) for zeros_idx in range(1, args.num_scale + 1)
    ]
    z_fix_list = z_fix_list + zero_list

    if args.validation:
        validateSinGAN(train_loader, networks, args.stage, args,
                       {"z_rec": z_fix_list})
        return

    elif args.test:
        validateSinGAN(train_loader, networks, args.stage, args,
                       {"z_rec": z_fix_list})
        return

    if not args.multiprocessing_distributed or (
            args.multiprocessing_distributed
            and args.rank % ngpus_per_node == 0):
        check_list = open(os.path.join(args.log_dir, "checkpoint.txt"), "a+")
        record_txt = open(os.path.join(args.log_dir, "record.txt"), "a+")
        record_txt.write('DATASET\t:\t{}\n'.format(args.dataset))
        record_txt.write('GANTYPE\t:\t{}\n'.format(args.gantype))
        record_txt.write('IMGTOUSE\t:\t{}\n'.format(args.img_to_use))
        record_txt.close()

    for stage in range(args.stage, args.num_scale + 1):
        if args.distributed:
            train_sampler.set_epoch(stage)

        trainSinGAN(train_loader, networks, {
            "d_opt": d_opt,
            "g_opt": g_opt
        }, stage, args, {"z_rec": z_fix_list})
        validateSinGAN(train_loader, networks, stage, args,
                       {"z_rec": z_fix_list})

        if args.distributed:
            discriminator.module.progress()
            generator.module.progress()
        else:
            discriminator.progress()
            generator.progress()

        networks = [discriminator, generator]

        if args.distributed:
            if args.gpu is not None:
                print('Distributed', args.gpu)
                torch.cuda.set_device(args.gpu)
                networks = [x.cuda(args.gpu) for x in networks]
                args.batch_size = int(args.batch_size / ngpus_per_node)
                args.workers = int(args.workers / ngpus_per_node)
                networks = [
                    torch.nn.parallel.DistributedDataParallel(
                        x, device_ids=[args.gpu], output_device=args.gpu)
                    for x in networks
                ]
            else:
                networks = [x.cuda() for x in networks]
                networks = [
                    torch.nn.parallel.DistributedDataParallel(x)
                    for x in networks
                ]

        elif args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            networks = [x.cuda(args.gpu) for x in networks]
        else:
            networks = [torch.nn.DataParallel(x).cuda() for x in networks]

        discriminator, generator, = networks

        # Freeze the already-trained scales and rebuild the optimizers for the new finest scale
        if args.distributed:
            for net_idx in range(generator.module.current_scale):
                for param in generator.module.sub_generators[
                        net_idx].parameters():
                    param.requires_grad = False
                for param in discriminator.module.sub_discriminators[
                        net_idx].parameters():
                    param.requires_grad = False

            d_opt = torch.optim.Adam(
                discriminator.module.sub_discriminators[
                    discriminator.module.current_scale].parameters(), 5e-4,
                (0.5, 0.999))
            g_opt = torch.optim.Adam(
                generator.module.sub_generators[
                    generator.module.current_scale].parameters(), 5e-4,
                (0.5, 0.999))
        else:
            for net_idx in range(generator.current_scale):
                for param in generator.sub_generators[net_idx].parameters():
                    param.requires_grad = False
                for param in discriminator.sub_discriminators[
                        net_idx].parameters():
                    param.requires_grad = False

            d_opt = torch.optim.Adam(
                discriminator.sub_discriminators[
                    discriminator.current_scale].parameters(), 5e-4,
                (0.5, 0.999))
            g_opt = torch.optim.Adam(
                generator.sub_generators[generator.current_scale].parameters(),
                5e-4, (0.5, 0.999))

        ##############
        # Save model #
        ##############
        if not args.multiprocessing_distributed or (
                args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            if stage == 0:
                check_list = open(os.path.join(args.log_dir, "checkpoint.txt"),
                                  "a+")
            save_checkpoint(
                {
                    'stage': stage + 1,
                    'D_state_dict': discriminator.state_dict(),
                    'G_state_dict': generator.state_dict(),
                    'd_optimizer': d_opt.state_dict(),
                    'g_optimizer': g_opt.state_dict(),
                    'img_to_use': args.img_to_use
                }, check_list, args.log_dir, stage + 1)
            if stage == args.num_scale:
                check_list.close()
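
For reference, the scale pyramid above grows geometrically by the paper's 4/3 factor. The sketch below restates the same num_scale / size_list computation as a standalone helper; the 25 px and 250 px sizes are illustrative values, not taken from this code.

import numpy as np

def build_size_list(img_size_min, img_size_max, scale_factor=4 / 3):
    # Number of geometric steps from the coarsest to the finest resolution.
    num_scale = int(np.round(np.log(img_size_max / img_size_min) / np.log(scale_factor)))
    # One image size per scale, truncated to whole pixels.
    return [int(img_size_min * scale_factor ** i) for i in range(num_scale + 1)]

# Example: 25 px -> 250 px gives 9 scales: [25, 33, 44, 59, 79, 105, 140, 187, 249]
print(build_size_list(25, 250))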
Example 3
def main_worker(gpu, ngpus_per_node, args):
    if len(args.gpu) == 1:
        args.gpu = 0
    else:
        args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:'+args.port,
                                world_size=args.world_size, rank=args.rank)

    # # of GT-classes
    args.num_cls = args.output_k

    # Classes to use
    args.att_to_use = list(range(400))  # classes 0-399

    # IIC statistics
    args.epoch_acc = []
    args.epoch_avg_subhead_acc = []
    args.epoch_stats = []

    # Logging
    logger = SummaryWriter(args.event_dir)

    # build model - return dict
    networks, opts = build_model(args)

    # load model if args.load_model is specified
    load_model(args, networks, opts)
    cudnn.benchmark = True

    # get dataset and data loader
    train_dataset, val_dataset = get_dataset(args)
    train_loader, val_loader, train_sampler = get_loader(args, {'train': train_dataset, 'val': val_dataset})

    # map the functions to execute - unsupervised / supervised / semi-supervised
    trainFunc, validationFunc = map_exec_func(args)

    # print all the arguments
    print_args(args)

    # All testing is done during training - no separate call is needed
    if args.validation:
        validationFunc(val_loader, networks, 999, args, {'logger': logger})
        return

    # Record the run arguments (main process only)
    if not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0):
        record_txt = open(os.path.join(args.log_dir, "record.txt"), "a+")
        for arg in vars(args):
            record_txt.write('{:35}{:20}\n'.format(arg, str(getattr(args, arg))))
        record_txt.close()

    # Run
    #validationFunc(val_loader, networks, 0, args, {'logger': logger, 'queue': queue})

    for epoch in range(args.start_epoch, args.epochs):
        print("START EPOCH[{}]".format(epoch+1))
        if (epoch + 1) % (args.epochs // 10) == 0:
            save_model(args, epoch, networks, opts)

        if args.distributed:
            train_sampler.set_epoch(epoch)

        if epoch == args.ema_start and 'GAN' in args.train_mode:
            if args.distributed:
                networks['G_EMA'].module.load_state_dict(networks['G'].module.state_dict())
            else:
                networks['G_EMA'].load_state_dict(networks['G'].state_dict())

        trainFunc(train_loader, networks, opts, epoch, args, {'logger': logger})

        validationFunc(val_loader, networks, epoch, args, {'logger': logger})
Example 4
def main_worker(args):

    ################
    # Define model #
    ################
    # 4/3 : scale factor in the paper
    scale_factor = 4 / 3
    tmp_scale = args.img_size_max / args.img_size_min
    args.num_scale = int(np.round(np.log(tmp_scale) / np.log(scale_factor)))
    args.size_list = [
        int(args.img_size_min * scale_factor**i)
        for i in range(args.num_scale + 1)
    ]

    discriminator = Discriminator()
    generator = Generator(args.img_size_min, args.num_scale, scale_factor)

    ######################
    # Loss and Optimizer #
    ######################
    d_opt = mindspore.nn.Adam(
        discriminator.sub_discriminators[0].get_parameters(), 5e-4, 0.5, 0.999)
    g_opt = mindspore.nn.Adam(generator.sub_generators[0].get_parameters(),
                              5e-4, 0.5, 0.999)

    ##############
    # Load model #
    ##############
    args.stage = 0
    if args.load_model is not None:
        check_load = open(os.path.join(args.log_dir, "checkpoint.txt"), 'r')
        to_restore = check_load.readlines()[-1].strip()
        load_file = os.path.join(args.log_dir, to_restore)
        if os.path.isfile(load_file):
            print("=> loading checkpoint '{}'".format(load_file))
            checkpoint = mindspore.load_checkpoint(
                load_file)  # MPS: replaces torch.load(..., map_location='cpu')
            for _ in range(int(checkpoint['stage'])):
                generator.progress()
                discriminator.progress()
            args.stage = checkpoint['stage']
            args.img_to_use = checkpoint['img_to_use']
            discriminator.load_state_dict(checkpoint['D_state_dict'])
            generator.load_state_dict(checkpoint['G_state_dict'])
            # MPS: does Adam.load_state_dict exist?
            d_opt.load_state_dict(checkpoint['d_optimizer'])
            g_opt.load_state_dict(checkpoint['g_optimizer'])
            print("=> loaded checkpoint '{}' (stage {})".format(
                load_file, checkpoint['stage']))
        else:
            print("=> no checkpoint found at '{}'".format(args.log_dir))

    ###########
    # Dataset #
    ###########
    train_dataset, _ = get_dataset(args.dataset, args)
    train_sampler = None

    train_loader = mindspore.DatasetHelper(train_dataset)  # MPS: parameters may need tuning

    ######################
    # Validate and Train #
    ######################
    op1 = mindspore.ops.Pad(((5, 5), (5, 5)))
    op2 = mindspore.ops.Pad(((5, 5), (5, 5)))
    z_fix_list = [op1(mindspore.ops.StandardNormal(3, args.size_list[0]))]
    zero_list = [
        op2(mindspore.ops.Zeros(3, args.size_list[zeros_idx]))
        for zeros_idx in range(1, args.num_scale + 1)
    ]
    z_fix_list = z_fix_list + zero_list
    networks = [discriminator, generator]

    if args.validation:
        validateSinGAN(train_loader, networks, args.stage, args,
                       {"z_rec": z_fix_list})
        return

    elif args.test:
        validateSinGAN(train_loader, networks, args.stage, args,
                       {"z_rec": z_fix_list})
        return

    check_list = open(os.path.join(args.log_dir, "checkpoint.txt"), "a+")
    record_txt = open(os.path.join(args.log_dir, "record.txt"), "a+")
    record_txt.write('DATASET\t:\t{}\n'.format(args.dataset))
    record_txt.write('GANTYPE\t:\t{}\n'.format(args.gantype))
    record_txt.write('IMGTOUSE\t:\t{}\n'.format(args.img_to_use))
    record_txt.close()

    for stage in range(args.stage, args.num_scale + 1):

        trainSinGAN(train_loader, networks, {
            "d_opt": d_opt,
            "g_opt": g_opt
        }, stage, args, {"z_rec": z_fix_list})
        validateSinGAN(train_loader, networks, stage, args,
                       {"z_rec": z_fix_list})
        discriminator.progress()
        generator.progress()

        # Rebuild the optimizers for the new finest-scale sub-networks
        d_opt = mindspore.nn.Adam(
            discriminator.sub_discriminators[
                discriminator.current_scale].get_parameters(), 5e-4, 0.5, 0.999)
        g_opt = mindspore.nn.Adam(
            generator.sub_generators[
                generator.current_scale].get_parameters(), 5e-4, 0.5, 0.999)
        ##############
        # Save model #
        ##############
        if stage == 0:
            check_list = open(os.path.join(args.log_dir, "checkpoint.txt"),
                              "a+")
        save_checkpoint(
            {
                'stage': stage + 1,
                'D_state_dict': discriminator.state_dict(),
                'G_state_dict': generator.state_dict(),
                'd_optimizer': d_opt.state_dict(),
                'g_optimizer': g_opt.state_dict(),
                'img_to_use': args.img_to_use
            }, check_list, args.log_dir, stage + 1)
        if stage == args.num_scale:
            check_list.close()
Example 5
def main_worker(gpu, ngpus_per_node, args):
    if len(args.gpu) == 1:
        args.gpu = 0
    else:
        args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend='nccl',
                                init_method='tcp://127.0.0.1:' + args.port,
                                world_size=args.world_size,
                                rank=args.rank)

    ################
    # Define model #
    ################
    # possible transform types: ['rotation', 'translation', 'shear', 'hflip', 'scale', 'odd']
    tmptftypes = []
    if 'R' in args.tftypes:
        tmptftypes.append('rotation')
    if 'T' in args.tftypes:
        tmptftypes.append('translation')
    if 'S' in args.tftypes:
        tmptftypes.append('shear')
    if 'H' in args.tftypes:
        tmptftypes.append('hflip')
    if 'C' in args.tftypes:
        tmptftypes.append('scale')
    if 'O' in args.tftypes:
        tmptftypes.append('odd')

    args.tftypes_org = args.tftypes
    args.tftypes = tmptftypes
    args.tfnums = [4, 3, 3, 2, 3, 5]
    args.tfval = {'T': 0.1, 'C': 0.3, 'S': 30, 'O': 3.0}

    print("=> Creating Classifier")
    if 'tf' in args.method:
        print('=> Use transform:\t', args.tftypes)
        print('=> Use nums:\t', args.tfnums)
        print('=> Use vals:\t', args.tfval)
        if 'vgg' in args.network:
            print("USE VGGCAMTF")
            classifier = vggcamtf(args.network,
                                  args.tftypes,
                                  args.tfnums,
                                  pretrained=True)
        elif 'seres' in args.network:
            print('USE SERESNET')
            classifier = serescamtf(args.network,
                                    args.tftypes,
                                    args.tfnums,
                                    pretrained=True)
        elif 'res' in args.network:
            print('USE RESCAMTF')
            classifier = rescamtf(args.network,
                                  args.tftypes,
                                  args.tfnums,
                                  pretrained=True)
        else:
            print("NOT IMPLEMENTED")
            return

    else:
        if args.dataset.lower() == 'imagenet':
            if 'vgg' in args.network:
                print('USE VGG')
                assert args.network in ['vggimg16']
                classifier = vggimg(args.network)
        else:
            if 'vgg' in args.network:
                print('USE VGG')
                classifier = vggcam(True,
                                    args.network,
                                    args.method,
                                    dataset=args.dataset)
            elif args.network == 'res':
                print('USE RES')
                print("NOT IMPLEMENTED")
                return
            else:
                print("NOT IMPLEMENTED")
                return

    networks = [classifier]

    if args.distributed:
        if args.gpu is not None:
            print('Distributed', args.gpu)
            torch.cuda.set_device(args.gpu)
            networks = [x.cuda(args.gpu) for x in networks]
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(args.workers / ngpus_per_node)
            networks = [
                torch.nn.parallel.DistributedDataParallel(
                    x, device_ids=[args.gpu], output_device=args.gpu)
                for x in networks
            ]
        else:
            networks = [x.cuda() for x in networks]
            networks = [
                torch.nn.parallel.DistributedDataParallel(x) for x in networks
            ]

    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        networks = [x.cuda(args.gpu) for x in networks]
    else:
        networks = [torch.nn.DataParallel(x).cuda() for x in networks]

    classifier, = networks
    ######################
    # Loss and Optimizer #
    ######################
    if args.distributed:
        c_opt = torch.optim.SGD(classifier.module.parameters(),
                                0.001,
                                momentum=0.9,
                                nesterov=True)
    else:
        c_opt = torch.optim.SGD(classifier.parameters(),
                                0.001,
                                momentum=0.9,
                                nesterov=True)

    ##############
    # Load model #
    ##############
    if args.load_model is not None:
        check_load = open(os.path.join(args.log_dir, "checkpoint.txt"), 'r')
        to_restore = check_load.readlines()[-1].strip()
        load_file = os.path.join(args.log_dir, to_restore)
        if os.path.isfile(load_file):
            print("=> loading checkpoint '{}'".format(load_file))
            checkpoint = torch.load(load_file, map_location='cpu')
            args.start_epoch = checkpoint['epoch']
            classifier.load_state_dict(checkpoint['C_state_dict'])
            c_opt.load_state_dict(checkpoint['c_optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                load_file, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.log_dir))

    cudnn.benchmark = True

    ###########
    # Dataset #
    ###########
    train_dataset, val_dataset = get_dataset(args.dataset, args)

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.val_batch,
                                             num_workers=args.workers,
                                             pin_memory=True)

    ######################
    # Validate and Train #
    ######################
    if args.validation:
        if 'tf' in args.method:
            validateTF(val_loader,
                       networks,
                       678,
                       args,
                       True,
                       additional={"dataset": val_dataset})
        else:
            if args.dataset.lower() == 'imagenet':
                validateImage(val_loader, networks, 123, args, True)
            else:
                validateFull(val_loader, networks, 123, args, True)
        return

    elif args.test:
        if 'tf' in args.method:
            validateTF(val_loader,
                       networks,
                       456,
                       args,
                       True,
                       additional={"dataset": val_dataset})
        else:
            if args.dataset.lower() == 'imagenet':
                validateImage(val_loader, networks, 456, args, True)
            else:
                validateFull(val_loader, networks, 456, args, True)
        return

    if not args.multiprocessing_distributed or (
            args.multiprocessing_distributed
            and args.rank % ngpus_per_node == 0):
        check_list = open(os.path.join(args.log_dir, "checkpoint.txt"), "a+")
        record_txt = open(os.path.join(args.log_dir, "record.txt"), "a+")
        record_txt.write('Network\t:\t{}\n'.format(args.network))
        record_txt.write('Method\t:\t{}\n'.format(args.method))
        record_txt.write('DATASET\t:\t{}\n'.format(args.dataset))
        record_txt.write('TFTYPES\t:\t{}\n'.format(args.tftypes))
        record_txt.write('TFNUMS\t:\t{}\n'.format(args.tfnums))
        record_txt.write('TFVALS\t:\t{}\n'.format(args.tfval))
        record_txt.close()

    best = 0.0
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        if epoch in [
                60,
        ]:
            for param_group in c_opt.param_groups:
                param_group['lr'] *= 0.2
                print(param_group['lr'])

        if 'tf' in args.method:
            acc_train = trainTF(train_loader, networks, [c_opt], epoch, args,
                                [None])
            acc_val = validateTF(val_loader,
                                 networks,
                                 epoch,
                                 args,
                                 saveimgs=False,
                                 additional={"dataset": val_dataset})
            best_criterion = acc_val['GT']
        else:
            acc_train = trainFull(train_loader, networks, [c_opt], epoch, args,
                                  [None])
            acc_val = validateFull(val_loader,
                                   networks,
                                   epoch,
                                   args,
                                   saveimgs=False)
            best_criterion = acc_val['top1loc']

        ##############
        # Save model #
        ##############
        if not args.multiprocessing_distributed or (
                args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            if epoch == 0:
                check_list = open(os.path.join(args.log_dir, "checkpoint.txt"),
                                  "a+")
            if best < best_criterion:
                best = best_criterion
                best_ckpt = [
                    f for f in glob(os.path.join(args.log_dir, "*.ckpt"))
                    if "best" in f
                ]
                if len(best_ckpt) != 0:
                    os.remove(best_ckpt[0])
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'C_state_dict': classifier.state_dict(),
                        'c_optimizer': c_opt.state_dict(),
                    }, check_list, args.log_dir,
                    'best' + str(best).replace(".", "_"))
            if (epoch + 1) % (args.epochs // 10) == 0:
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'C_state_dict': classifier.state_dict(),
                        'c_optimizer': c_opt.state_dict(),
                    }, check_list, args.log_dir, epoch + 1)
            if epoch == (args.epochs - 1):
                check_list.close()
            record_txt = open(os.path.join(args.log_dir, "record.txt"), "a+")
            if 'tf' in args.method:
                record_txt.write(
                    'Epoch : {:3d}, Train R : {:.4f} T : {:.4f} S : {:.4f} H : {:.4f} C : {:.4f} O : {:.4f},'
                    ' VAL R : {:.4f} T : {:.4f} S : {:.4f} H : {:.4f} C : {:.4f} O : {:.4f}, GT : {:.4f}\n'
                    .format(epoch, acc_train['R'], acc_train['T'],
                            acc_train['S'], acc_train['H'], acc_train['C'],
                            acc_train['O'], acc_val['R'], acc_val['T'],
                            acc_val['S'], acc_val['H'], acc_val['C'],
                            acc_val['O'], acc_val['GT']))
            else:
                record_txt.write(
                    'Epoch : {:3d}, Train top1 : {:.4f} top5 : {:.4f} ,'
                    ' VAL top1 : {:.4f} top5 : {:.4f} LOC : {:.4f} GT : {:.4f}\n'
                    .format(epoch, acc_train['top1acc'], acc_train['top5acc'],
                            acc_val['top1acc'], acc_val['top5acc'],
                            acc_val['top1loc'], acc_val['gtknown']))

            record_txt.close()
            copyfile(os.path.join(args.log_dir, "record.txt"),
                     os.path.join(args.log_dir, "recordNOW.txt"))

            print('BEST LOC EVER : {:.3f}'.format(best))
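
The resume logic in Examples 2, 4 and 5 finds the newest checkpoint by reading the last line of checkpoint.txt. The real save_checkpoint used above is not shown; the sketch below only illustrates the convention the load path implies, and the 'model_{}.ckpt' file naming is hypothetical.

import os
import torch

def save_checkpoint(state, check_list, log_dir, tag):
    # Write the checkpoint, then append its filename to the open checkpoint.txt
    # handle so the loader can pick up the most recent entry on resume.
    filename = 'model_{}.ckpt'.format(tag)
    torch.save(state, os.path.join(log_dir, filename))
    check_list.write(filename + '\n')
    check_list.flush()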