Example #1
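import os

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.nn.functional as F

# Discriminator, Generator, get_dataset, trainSinGAN, validateSinGAN and
# save_checkpoint are project-local helpers assumed to come from the
# surrounding repository.
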
def main_worker(gpu, ngpus_per_node, args):
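    # args.gpu arrives as a sequence of device ids; collapse it to the single
    # device index this process should use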
    if len(args.gpu) == 1:
        args.gpu = 0
    else:
        args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.multiprocessing_distributed:
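            # global rank = node rank * GPUs per node + local GPU index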
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend='nccl',
                                init_method='tcp://127.0.0.1:' + args.port,
                                world_size=args.world_size,
                                rank=args.rank)

    ################
    # Define model #
    ################
    # 4/3 : scale factor in the paper
    scale_factor = 4 / 3
    tmp_scale = args.img_size_max / args.img_size_min
    args.num_scale = int(np.round(np.log(tmp_scale) / np.log(scale_factor)))
    args.size_list = [
        int(args.img_size_min * scale_factor**i)
        for i in range(args.num_scale + 1)
    ]
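    # e.g. img_size_min=25 and img_size_max=250 give num_scale = 8 and
    # size_list = [25, 33, 44, 59, 79, 105, 140, 187, 249]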

    discriminator = Discriminator()
    generator = Generator(args.img_size_min, args.num_scale, scale_factor)

    networks = [discriminator, generator]

    if args.distributed:
        if args.gpu is not None:
            print('Distributed to', args.gpu)
            torch.cuda.set_device(args.gpu)
            networks = [x.cuda(args.gpu) for x in networks]
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(args.workers / ngpus_per_node)
            networks = [
                torch.nn.parallel.DistributedDataParallel(
                    x, device_ids=[args.gpu], output_device=args.gpu)
                for x in networks
            ]
        else:
            networks = [x.cuda() for x in networks]
            networks = [
                torch.nn.parallel.DistributedDataParallel(x) for x in networks
            ]

    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        networks = [x.cuda(args.gpu) for x in networks]
    else:
        networks = [torch.nn.DataParallel(x).cuda() for x in networks]

    discriminator, generator = networks

    ######################
    # Loss and Optimizer #
    ######################
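    # Stage 0 trains only the coarsest sub-networks; the stage loop below
    # rebuilds these optimizers each time a new scale is added.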
    if args.distributed:
        d_opt = torch.optim.Adam(
            discriminator.module.sub_discriminators[0].parameters(), 5e-4,
            (0.5, 0.999))
        g_opt = torch.optim.Adam(
            generator.module.sub_generators[0].parameters(), 5e-4,
            (0.5, 0.999))
    else:
        d_opt = torch.optim.Adam(
            discriminator.sub_discriminators[0].parameters(), 5e-4,
            (0.5, 0.999))
        g_opt = torch.optim.Adam(generator.sub_generators[0].parameters(),
                                 5e-4, (0.5, 0.999))

    ##############
    # Load model #
    ##############
    args.stage = 0
    if args.load_model is not None:
        with open(os.path.join(args.log_dir, "checkpoint.txt"), 'r') as f:
            to_restore = f.readlines()[-1].strip()
        load_file = os.path.join(args.log_dir, to_restore)
        if os.path.isfile(load_file):
            print("=> loading checkpoint '{}'".format(load_file))
            checkpoint = torch.load(load_file, map_location='cpu')
            # the networks may already be wrapped in (Distributed)DataParallel
            # here; unwrap them before growing, since the wrappers themselves
            # have no progress() method
            g = generator.module if hasattr(generator, 'module') else generator
            d = (discriminator.module
                 if hasattr(discriminator, 'module') else discriminator)
            for _ in range(int(checkpoint['stage'])):
                g.progress()
                d.progress()
            networks = [d, g]

            if args.distributed:
                if args.gpu is not None:
                    print('Distributed to', args.gpu)
                    torch.cuda.set_device(args.gpu)
                    networks = [x.cuda(args.gpu) for x in networks]
                    # batch_size and workers were already divided per process
                    # at start-up; dividing again would shrink them twice
                    networks = [
                        torch.nn.parallel.DistributedDataParallel(
                            x, device_ids=[args.gpu], output_device=args.gpu)
                        for x in networks
                    ]
                else:
                    networks = [x.cuda() for x in networks]
                    networks = [
                        torch.nn.parallel.DistributedDataParallel(x)
                        for x in networks
                    ]

            elif args.gpu is not None:
                torch.cuda.set_device(args.gpu)
                networks = [x.cuda(args.gpu) for x in networks]
            else:
                networks = [torch.nn.DataParallel(x).cuda() for x in networks]

            discriminator, generator = networks

            args.stage = checkpoint['stage']
            args.img_to_use = checkpoint['img_to_use']
            discriminator.load_state_dict(checkpoint['D_state_dict'])
            generator.load_state_dict(checkpoint['G_state_dict'])
            d_opt.load_state_dict(checkpoint['d_optimizer'])
            g_opt.load_state_dict(checkpoint['g_optimizer'])
            print("=> loaded checkpoint '{}' (stage {})".format(
                load_file, checkpoint['stage']))
        else:
            print("=> no checkpoint found at '{}'".format(args.log_dir))

    # let cuDNN autotune convolution algorithms; input sizes are fixed per stage
    cudnn.benchmark = True

    ###########
    # Dataset #
    ###########
    train_dataset, _ = get_dataset(args.dataset, args)

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    ######################
    # Validate and Train #
    ######################
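    # Fixed reconstruction noise (z_rec): a random code at the coarsest scale,
    # zeros at every finer scale, each zero-padded by 5 px per side.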
    z_fix_list = [
        F.pad(torch.randn(args.batch_size, 3, args.size_list[0],
                          args.size_list[0]), [5, 5, 5, 5],
              value=0)
    ]
    zero_list = [
        F.pad(torch.zeros(args.batch_size, 3, args.size_list[zeros_idx],
                          args.size_list[zeros_idx]), [5, 5, 5, 5],
              value=0) for zeros_idx in range(1, args.num_scale + 1)
    ]
    z_fix_list = z_fix_list + zero_list

    if args.validation or args.test:
        validateSinGAN(train_loader, networks, args.stage, args,
                       {"z_rec": z_fix_list})
        return

    if (not args.multiprocessing_distributed
            or args.rank % ngpus_per_node == 0):
        check_list = open(os.path.join(args.log_dir, "checkpoint.txt"), "a+")
        record_txt = open(os.path.join(args.log_dir, "record.txt"), "a+")
        record_txt.write('DATASET\t:\t{}\n'.format(args.dataset))
        record_txt.write('GANTYPE\t:\t{}\n'.format(args.gantype))
        record_txt.write('IMGTOUSE\t:\t{}\n'.format(args.img_to_use))
        record_txt.close()

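    # Progressive training: one stage per scale, from the coarsest to the finest.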
    for stage in range(args.stage, args.num_scale + 1):
        if args.distributed:
            train_sampler.set_epoch(stage)

        trainSinGAN(train_loader, networks, {
            "d_opt": d_opt,
            "g_opt": g_opt
        }, stage, args, {"z_rec": z_fix_list})
        validateSinGAN(train_loader, networks, stage, args,
                       {"z_rec": z_fix_list})

        # unwrap (Distributed)DataParallel containers before growing; the
        # wrappers themselves have no progress() method
        d = (discriminator.module
             if hasattr(discriminator, 'module') else discriminator)
        g = generator.module if hasattr(generator, 'module') else generator
        d.progress()
        g.progress()

        networks = [d, g]

        if args.distributed:
            if args.gpu is not None:
                print('Distributed', args.gpu)
                torch.cuda.set_device(args.gpu)
                networks = [x.cuda(args.gpu) for x in networks]
                # batch_size and workers were already divided per process at
                # start-up; don't divide them again on every stage
                networks = [
                    torch.nn.parallel.DistributedDataParallel(
                        x, device_ids=[args.gpu], output_device=args.gpu)
                    for x in networks
                ]
            else:
                networks = [x.cuda() for x in networks]
                networks = [
                    torch.nn.parallel.DistributedDataParallel(x)
                    for x in networks
                ]

        elif args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            networks = [x.cuda(args.gpu) for x in networks]
        else:
            networks = [torch.nn.DataParallel(x).cuda() for x in networks]

        discriminator, generator = networks

        # Freeze all finished scales so that only the newest (finest)
        # sub-networks keep training, then rebuild their optimizers
        if args.distributed:
            for net_idx in range(generator.module.current_scale):
                for param in generator.module.sub_generators[
                        net_idx].parameters():
                    param.requires_grad = False
                for param in discriminator.module.sub_discriminators[
                        net_idx].parameters():
                    param.requires_grad = False

            d_opt = torch.optim.Adam(
                discriminator.module.sub_discriminators[
                    discriminator.module.current_scale].parameters(), 5e-4,
                (0.5, 0.999))
            g_opt = torch.optim.Adam(
                generator.module.sub_generators[
                    generator.module.current_scale].parameters(), 5e-4,
                (0.5, 0.999))
        else:
            for net_idx in range(generator.current_scale):
                for param in generator.sub_generators[net_idx].parameters():
                    param.requires_grad = False
                for param in discriminator.sub_discriminators[
                        net_idx].parameters():
                    param.requires_grad = False

            d_opt = torch.optim.Adam(
                discriminator.sub_discriminators[
                    discriminator.current_scale].parameters(), 5e-4,
                (0.5, 0.999))
            g_opt = torch.optim.Adam(
                generator.sub_generators[generator.current_scale].parameters(),
                5e-4, (0.5, 0.999))

        ##############
        # Save model #
        ##############
        if (not args.multiprocessing_distributed
                or args.rank % ngpus_per_node == 0):
            if stage == 0:
                check_list = open(os.path.join(args.log_dir, "checkpoint.txt"),
                                  "a+")
            save_checkpoint(
                {
                    'stage': stage + 1,
                    'D_state_dict': discriminator.state_dict(),
                    'G_state_dict': generator.state_dict(),
                    'd_optimizer': d_opt.state_dict(),
                    'g_optimizer': g_opt.state_dict(),
                    'img_to_use': args.img_to_use
                }, check_list, args.log_dir, stage + 1)
            if stage == args.num_scale:
                check_list.close()
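
For reference, a main_worker like the one above is usually launched with
torch.multiprocessing.spawn, one process per GPU. A minimal sketch of the
launcher, where parse_args stands in for the example's (unshown) argparse
setup:

import torch
import torch.multiprocessing as mp

def main():
    args = parse_args()  # placeholder for the example's argument parsing
    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # world_size becomes the total number of worker processes
        args.world_size = ngpus_per_node * args.world_size
        # spawn calls main_worker(local_gpu_index, ngpus_per_node, args)
        mp.spawn(main_worker, nprocs=ngpus_per_node,
                 args=(ngpus_per_node, args))
    else:
        main_worker(args.gpu, ngpus_per_node, args)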
Example #2
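import os

import numpy as np
import mindspore

# Discriminator, Generator, get_dataset, trainSinGAN, validateSinGAN and
# save_checkpoint are assumed to be MindSpore ports of the project-local
# helpers used in Example #1.
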
def main_worker(args):

    ################
    # Define model #
    ################
    # 4/3 : scale factor in the paper
    scale_factor = 4 / 3
    tmp_scale = args.img_size_max / args.img_size_min
    args.num_scale = int(np.round(np.log(tmp_scale) / np.log(scale_factor)))
    args.size_list = [
        int(args.img_size_min * scale_factor**i)
        for i in range(args.num_scale + 1)
    ]

    discriminator = Discriminator()
    generator = Generator(args.img_size_min, args.num_scale, scale_factor)

    ######################
    # Loss and Optimizer #
    ######################
    d_opt = mindspore.nn.Adam(
        discriminator.sub_discriminators[0].get_parameters(), 5e-4, 0.5, 0.999)
    g_opt = mindspore.nn.Adam(generator.sub_generators[0].get_parameters(),
                              5e-4, 0.5, 0.999)
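    # mindspore.nn.Adam takes (params, learning_rate, beta1, beta2)
    # positionally, unlike torch.optim.Adam, which takes betas as one tuple.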

    ##############
    # Load model #
    ##############
    args.stage = 0
    if args.load_model is not None:
        with open(os.path.join(args.log_dir, "checkpoint.txt"), 'r') as f:
            to_restore = f.readlines()[-1].strip()
        load_file = os.path.join(args.log_dir, to_restore)
        if os.path.isfile(load_file):
            print("=> loading checkpoint '{}'".format(load_file))
            # load_checkpoint returns a dict of Parameters; MindSpore needs
            # no map_location argument
            checkpoint = mindspore.load_checkpoint(load_file)
            for _ in range(int(checkpoint['stage'])):
                generator.progress()
                discriminator.progress()
            args.stage = checkpoint['stage']
            args.img_to_use = checkpoint['img_to_use']
            # TODO: MindSpore Cells and optimizers have no load_state_dict;
            # mindspore.load_param_into_net would be needed here instead
            discriminator.load_state_dict(checkpoint['D_state_dict'])
            generator.load_state_dict(checkpoint['G_state_dict'])
            d_opt.load_state_dict(checkpoint['d_optimizer'])
            g_opt.load_state_dict(checkpoint['g_optimizer'])
            print("=> loaded checkpoint '{}' (stage {})".format(
                load_file, checkpoint['stage']))
        else:
            print("=> no checkpoint found at '{}'".format(args.log_dir))

    ###########
    # Dataset #
    ###########
    train_dataset, _ = get_dataset(args.dataset, args)
    train_sampler = None

    # DatasetHelper arguments may need tuning for this dataset
    train_loader = mindspore.DatasetHelper(train_dataset)

    ######################
    # Validate and Train #
    ######################
    # Fixed reconstruction noise (z_rec), mirroring the PyTorch version above.
    # MindSpore operators are constructed first, then called on inputs.
    pad = mindspore.ops.Pad(((0, 0), (0, 0), (5, 5), (5, 5)))
    stdnormal = mindspore.ops.StandardNormal()
    zeros = mindspore.ops.Zeros()
    z_fix_list = [
        pad(stdnormal((args.batch_size, 3, args.size_list[0],
                       args.size_list[0])))
    ]
    zero_list = [
        pad(zeros((args.batch_size, 3, args.size_list[zeros_idx],
                   args.size_list[zeros_idx]), mindspore.float32))
        for zeros_idx in range(1, args.num_scale + 1)
    ]
    z_fix_list = z_fix_list + zero_list

    # defined here because the validation/test paths below use it
    networks = [discriminator, generator]

    if args.validation or args.test:
        validateSinGAN(train_loader, networks, args.stage, args,
                       {"z_rec": z_fix_list})
        return

    check_list = open(os.path.join(args.log_dir, "checkpoint.txt"), "a+")
    record_txt = open(os.path.join(args.log_dir, "record.txt"), "a+")
    record_txt.write('DATASET\t:\t{}\n'.format(args.dataset))
    record_txt.write('GANTYPE\t:\t{}\n'.format(args.gantype))
    record_txt.write('IMGTOUSE\t:\t{}\n'.format(args.img_to_use))
    record_txt.close()

    for stage in range(args.stage, args.num_scale + 1):

        trainSinGAN(train_loader, networks, {
            "d_opt": d_opt,
            "g_opt": g_opt
        }, stage, args, {"z_rec": z_fix_list})
        validateSinGAN(train_loader, networks, stage, args,
                       {"z_rec": z_fix_list})
        discriminator.progress()
        generator.progress()

        # Rebuild the optimizers for the newest (finest) sub-networks
        d_opt = mindspore.nn.Adam(
            discriminator.sub_discriminators[
                discriminator.current_scale].get_parameters(), 5e-4, 0.5,
            0.999)
        g_opt = mindspore.nn.Adam(
            generator.sub_generators[generator.current_scale].get_parameters(),
            5e-4, 0.5, 0.999)
        ##############
        # Save model #
        ##############
        if stage == 0:
            check_list = open(os.path.join(args.log_dir, "checkpoint.txt"),
                              "a+")
        # NOTE: state_dict() below is torch-style; see the MindSpore save
        # sketch after this example
        save_checkpoint(
            {
                'stage': stage + 1,
                'D_state_dict': discriminator.state_dict(),
                'G_state_dict': generator.state_dict(),
                'd_optimizer': d_opt.state_dict(),
                'g_optimizer': g_opt.state_dict(),
                'img_to_use': args.img_to_use
            }, check_list, args.log_dir, stage + 1)
        if stage == args.num_scale:
            check_list.close()
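
Note that state_dict() above is a torch API; MindSpore Cells expose
parameters_dict() and are serialized with mindspore.save_checkpoint. A minimal
sketch of a ported save step, assuming per-stage .ckpt filenames
(save_checkpoint_ms is a hypothetical helper, not part of the original code):

import os
import mindspore

def save_checkpoint_ms(generator, discriminator, log_dir, stage):
    # mindspore.save_checkpoint writes a Cell's parameters to a .ckpt file
    mindspore.save_checkpoint(
        generator, os.path.join(log_dir, "G_stage{}.ckpt".format(stage)))
    mindspore.save_checkpoint(
        discriminator, os.path.join(log_dir, "D_stage{}.ckpt".format(stage)))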