import os

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.nn.functional as F

# Discriminator, Generator, get_dataset, trainSinGAN, validateSinGAN and
# save_checkpoint are assumed to be imported from this repo's own modules.


def main_worker(gpu, ngpus_per_node, args):
    if len(args.gpu) == 1:
        args.gpu = 0
    else:
        args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.multiprocessing_distributed:
            # Global rank = node rank * GPUs per node + local GPU index.
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend='nccl',
                                init_method='tcp://127.0.0.1:' + args.port,
                                world_size=args.world_size,
                                rank=args.rank)

    ################
    # Define model #
    ################
    # 4/3 : scale factor in the paper
    scale_factor = 4 / 3
    tmp_scale = args.img_size_max / args.img_size_min
    args.num_scale = int(np.round(np.log(tmp_scale) / np.log(scale_factor)))
    args.size_list = [
        int(args.img_size_min * scale_factor**i)
        for i in range(args.num_scale + 1)
    ]

    discriminator = Discriminator()
    generator = Generator(args.img_size_min, args.num_scale, scale_factor)

    networks = [discriminator, generator]

    if args.distributed:
        if args.gpu is not None:
            print('Distributed to', args.gpu)
            torch.cuda.set_device(args.gpu)
            networks = [x.cuda(args.gpu) for x in networks]
            # Split the global batch and workers across the GPUs of this node.
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(args.workers / ngpus_per_node)
            networks = [
                torch.nn.parallel.DistributedDataParallel(
                    x, device_ids=[args.gpu], output_device=args.gpu)
                for x in networks
            ]
        else:
            networks = [x.cuda() for x in networks]
            networks = [
                torch.nn.parallel.DistributedDataParallel(x)
                for x in networks
            ]
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        networks = [x.cuda(args.gpu) for x in networks]
    else:
        networks = [torch.nn.DataParallel(x).cuda() for x in networks]

    discriminator, generator = networks

    ######################
    # Loss and Optimizer #
    ######################
    # Only the coarsest sub-network is trained at stage 0; later stages
    # rebuild the optimizers for the current finest scale.
    if args.distributed:
        d_opt = torch.optim.Adam(
            discriminator.module.sub_discriminators[0].parameters(),
            5e-4, (0.5, 0.999))
        g_opt = torch.optim.Adam(
            generator.module.sub_generators[0].parameters(),
            5e-4, (0.5, 0.999))
    else:
        d_opt = torch.optim.Adam(
            discriminator.sub_discriminators[0].parameters(),
            5e-4, (0.5, 0.999))
        g_opt = torch.optim.Adam(generator.sub_generators[0].parameters(),
                                 5e-4, (0.5, 0.999))

    ##############
    # Load model #
    ##############
    args.stage = 0
    if args.load_model is not None:
        check_load = open(os.path.join(args.log_dir, "checkpoint.txt"), 'r')
        to_restore = check_load.readlines()[-1].strip()
        load_file = os.path.join(args.log_dir, to_restore)
        if os.path.isfile(load_file):
            print("=> loading checkpoint '{}'".format(load_file))
            checkpoint = torch.load(load_file, map_location='cpu')
            # Grow both pyramids to the stage stored in the checkpoint
            # before restoring the weights. The networks are already
            # DDP-wrapped in distributed mode, so go through .module.
            for _ in range(int(checkpoint['stage'])):
                if args.distributed:
                    generator.module.progress()
                    discriminator.module.progress()
                else:
                    generator.progress()
                    discriminator.progress()
            networks = [discriminator, generator]
            if args.distributed:
                if args.gpu is not None:
                    print('Distributed to', args.gpu)
                    torch.cuda.set_device(args.gpu)
                    networks = [x.cuda(args.gpu) for x in networks]
                    args.batch_size = int(args.batch_size / ngpus_per_node)
                    args.workers = int(args.workers / ngpus_per_node)
                    networks = [
                        torch.nn.parallel.DistributedDataParallel(
                            x, device_ids=[args.gpu], output_device=args.gpu)
                        for x in networks
                    ]
                else:
                    networks = [x.cuda() for x in networks]
                    networks = [
                        torch.nn.parallel.DistributedDataParallel(x)
                        for x in networks
                    ]
            elif args.gpu is not None:
                torch.cuda.set_device(args.gpu)
                networks = [x.cuda(args.gpu) for x in networks]
            else:
                networks = [torch.nn.DataParallel(x).cuda() for x in networks]

            discriminator, generator = networks

            args.stage = checkpoint['stage']
            args.img_to_use = checkpoint['img_to_use']
            discriminator.load_state_dict(checkpoint['D_state_dict'])
            generator.load_state_dict(checkpoint['G_state_dict'])
            d_opt.load_state_dict(checkpoint['d_optimizer'])
            g_opt.load_state_dict(checkpoint['g_optimizer'])
            print("=> loaded checkpoint '{}' (stage {})".format(
                load_file, checkpoint['stage']))
        else:
            print("=> no checkpoint found at '{}'".format(args.log_dir))

    cudnn.benchmark = True

    ###########
    # Dataset #
    ###########
    train_dataset, _ = get_dataset(args.dataset, args)
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    ######################
    # Validate and Train #
    ######################
    # Fixed reconstruction noise: random at the coarsest scale, zeros at all
    # finer scales, each padded by 5 px on every side.
    z_fix_list = [
        F.pad(torch.randn(args.batch_size, 3, args.size_list[0],
                          args.size_list[0]), [5, 5, 5, 5], value=0)
    ]
    zero_list = [
        F.pad(torch.zeros(args.batch_size, 3, args.size_list[zeros_idx],
                          args.size_list[zeros_idx]), [5, 5, 5, 5], value=0)
        for zeros_idx in range(1, args.num_scale + 1)
    ]
    z_fix_list = z_fix_list + zero_list

    if args.validation:
        validateSinGAN(train_loader, networks, args.stage, args,
                       {"z_rec": z_fix_list})
        return
    elif args.test:
        validateSinGAN(train_loader, networks, args.stage, args,
                       {"z_rec": z_fix_list})
        return

    # Only the rank-0 process of each node writes the log files.
    if not args.multiprocessing_distributed or (
            args.multiprocessing_distributed
            and args.rank % ngpus_per_node == 0):
        check_list = open(os.path.join(args.log_dir, "checkpoint.txt"), "a+")
        record_txt = open(os.path.join(args.log_dir, "record.txt"), "a+")
        record_txt.write('DATASET\t:\t{}\n'.format(args.dataset))
        record_txt.write('GANTYPE\t:\t{}\n'.format(args.gantype))
        record_txt.write('IMGTOUSE\t:\t{}\n'.format(args.img_to_use))
        record_txt.close()

    for stage in range(args.stage, args.num_scale + 1):
        if args.distributed:
            train_sampler.set_epoch(stage)

        trainSinGAN(train_loader, networks, {
            "d_opt": d_opt,
            "g_opt": g_opt
        }, stage, args, {"z_rec": z_fix_list})
        validateSinGAN(train_loader, networks, stage, args,
                       {"z_rec": z_fix_list})

        # Grow both pyramids by one scale for the next stage.
        if args.distributed:
            discriminator.module.progress()
            generator.module.progress()
        else:
            discriminator.progress()
            generator.progress()

        networks = [discriminator, generator]

        if args.distributed:
            if args.gpu is not None:
                print('Distributed', args.gpu)
                torch.cuda.set_device(args.gpu)
                networks = [x.cuda(args.gpu) for x in networks]
                args.batch_size = int(args.batch_size / ngpus_per_node)
                args.workers = int(args.workers / ngpus_per_node)
                networks = [
                    torch.nn.parallel.DistributedDataParallel(
                        x, device_ids=[args.gpu], output_device=args.gpu)
                    for x in networks
                ]
            else:
                networks = [x.cuda() for x in networks]
                networks = [
                    torch.nn.parallel.DistributedDataParallel(x)
                    for x in networks
                ]
        elif args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            networks = [x.cuda(args.gpu) for x in networks]
        else:
            networks = [torch.nn.DataParallel(x).cuda() for x in networks]

        discriminator, generator = networks

        # Update the networks at finest scale: freeze every earlier
        # sub-network and rebuild the optimizers for the newest one.
        if args.distributed:
            for net_idx in range(generator.module.current_scale):
                for param in generator.module.sub_generators[
                        net_idx].parameters():
                    param.requires_grad = False
                for param in discriminator.module.sub_discriminators[
                        net_idx].parameters():
                    param.requires_grad = False
            d_opt = torch.optim.Adam(
                discriminator.module.sub_discriminators[
                    discriminator.module.current_scale].parameters(),
                5e-4, (0.5, 0.999))
            g_opt = torch.optim.Adam(
                generator.module.sub_generators[
                    generator.module.current_scale].parameters(),
                5e-4, (0.5, 0.999))
        else:
            for net_idx in range(generator.current_scale):
                for param in generator.sub_generators[net_idx].parameters():
                    param.requires_grad = False
                for param in discriminator.sub_discriminators[
                        net_idx].parameters():
                    param.requires_grad = False
            d_opt = torch.optim.Adam(
                discriminator.sub_discriminators[
                    discriminator.current_scale].parameters(),
                5e-4, (0.5, 0.999))
            g_opt = torch.optim.Adam(
                generator.sub_generators[generator.current_scale].parameters(),
                5e-4, (0.5, 0.999))

        ##############
        # Save model #
        ##############
        if not args.multiprocessing_distributed or (
                args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            if stage == 0:
                check_list = open(
                    os.path.join(args.log_dir, "checkpoint.txt"), "a+")
            save_checkpoint(
                {
                    'stage': stage + 1,
                    'D_state_dict': discriminator.state_dict(),
                    'G_state_dict': generator.state_dict(),
                    'd_optimizer': d_opt.state_dict(),
                    'g_optimizer': g_opt.state_dict(),
                    'img_to_use': args.img_to_use
                }, check_list, args.log_dir, stage + 1)
            if stage == args.num_scale:
                check_list.close()
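# ---------------------------------------------------------------------------
# A minimal sketch of how the worker above would typically be launched; the
# `parse_args` helper and this `main` wrapper are assumptions for
# illustration, not part of this repo.
def main():
    import torch.multiprocessing as mp
    args = parse_args()  # hypothetical argparse helper
    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # One training process per GPU; mp.spawn calls
        # main_worker(local_gpu_index, ngpus_per_node, args) in each process.
        args.world_size = ngpus_per_node * args.world_size
        mp.spawn(main_worker, nprocs=ngpus_per_node,
                 args=(ngpus_per_node, args))
    else:
        # Single-process path; the worker resolves args.gpu itself.
        main_worker(0, ngpus_per_node, args)
# ---------------------------------------------------------------------------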
# MindSpore port of the worker above (single device; no DDP equivalent).
import os

import mindspore
import numpy as np

# Discriminator, Generator, get_dataset, trainSinGAN, validateSinGAN and
# save_checkpoint are assumed to come from this repo's own modules.


def main_worker(args):
    ################
    # Define model #
    ################
    # 4/3 : scale factor in the paper
    scale_factor = 4 / 3
    tmp_scale = args.img_size_max / args.img_size_min
    args.num_scale = int(np.round(np.log(tmp_scale) / np.log(scale_factor)))
    args.size_list = [
        int(args.img_size_min * scale_factor**i)
        for i in range(args.num_scale + 1)
    ]

    discriminator = Discriminator()
    generator = Generator(args.img_size_min, args.num_scale, scale_factor)
    networks = [discriminator, generator]

    ######################
    # Loss and Optimizer #
    ######################
    d_opt = mindspore.nn.Adam(
        discriminator.sub_discriminators[0].get_parameters(),
        5e-4, 0.5, 0.999)
    g_opt = mindspore.nn.Adam(generator.sub_generators[0].get_parameters(),
                              5e-4, 0.5, 0.999)

    ##############
    # Load model #
    ##############
    args.stage = 0
    if args.load_model is not None:
        check_load = open(os.path.join(args.log_dir, "checkpoint.txt"), 'r')
        to_restore = check_load.readlines()[-1].strip()
        load_file = os.path.join(args.log_dir, to_restore)
        if os.path.isfile(load_file):
            print("=> loading checkpoint '{}'".format(load_file))
            # MindSpore loads checkpoints on the host, so there is no
            # equivalent of PyTorch's map_location='cpu' to pass here.
            checkpoint = mindspore.load_checkpoint(load_file)
            for _ in range(int(checkpoint['stage'])):
                generator.progress()
                discriminator.progress()
            args.stage = checkpoint['stage']
            args.img_to_use = checkpoint['img_to_use']
            # Porting note: MindSpore Cells have no load_state_dict;
            # mindspore.load_param_into_net is the usual counterpart.
            discriminator.load_state_dict(checkpoint['D_state_dict'])
            generator.load_state_dict(checkpoint['G_state_dict'])
            # Porting note: check whether MindSpore optimizers offer a
            # load_state_dict equivalent.
            d_opt.load_state_dict(checkpoint['d_optimizer'])
            g_opt.load_state_dict(checkpoint['g_optimizer'])
            print("=> loaded checkpoint '{}' (stage {})".format(
                load_file, checkpoint['stage']))
        else:
            print("=> no checkpoint found at '{}'".format(args.log_dir))

    ###########
    # Dataset #
    ###########
    train_dataset, _ = get_dataset(args.dataset, args)
    train_sampler = None
    # Porting note: the arguments here may need tuning.
    train_loader = mindspore.DatasetHelper(train_dataset)

    ######################
    # Validate and Train #
    ######################
    # Pad only the two spatial dims of the NCHW noise tensors by 5 px,
    # mirroring F.pad(..., [5, 5, 5, 5]) in the PyTorch version.
    pad = mindspore.ops.Pad(((0, 0), (0, 0), (5, 5), (5, 5)))
    stdnormal = mindspore.ops.StandardNormal()
    zeros = mindspore.ops.Zeros()
    z_fix_list = [
        pad(stdnormal((args.batch_size, 3, args.size_list[0],
                       args.size_list[0])))
    ]
    zero_list = [
        pad(zeros((args.batch_size, 3, args.size_list[zeros_idx],
                   args.size_list[zeros_idx]), mindspore.float32))
        for zeros_idx in range(1, args.num_scale + 1)
    ]
    z_fix_list = z_fix_list + zero_list

    if args.validation:
        validateSinGAN(train_loader, networks, args.stage, args,
                       {"z_rec": z_fix_list})
        return
    elif args.test:
        validateSinGAN(train_loader, networks, args.stage, args,
                       {"z_rec": z_fix_list})
        return

    check_list = open(os.path.join(args.log_dir, "checkpoint.txt"), "a+")
    record_txt = open(os.path.join(args.log_dir, "record.txt"), "a+")
    record_txt.write('DATASET\t:\t{}\n'.format(args.dataset))
    record_txt.write('GANTYPE\t:\t{}\n'.format(args.gantype))
    record_txt.write('IMGTOUSE\t:\t{}\n'.format(args.img_to_use))
    record_txt.close()

    for stage in range(args.stage, args.num_scale + 1):
        trainSinGAN(train_loader, networks, {
            "d_opt": d_opt,
            "g_opt": g_opt
        }, stage, args, {"z_rec": z_fix_list})
        validateSinGAN(train_loader, networks, stage, args,
                       {"z_rec": z_fix_list})

        # Grow both pyramids by one scale for the next stage.
        discriminator.progress()
        generator.progress()

        # Update the networks at finest scale.
        d_opt = mindspore.nn.Adam(
            discriminator.sub_discriminators[
                discriminator.current_scale].get_parameters(),
            5e-4, 0.5, 0.999)
        g_opt = mindspore.nn.Adam(
            generator.sub_generators[
                generator.current_scale].get_parameters(),
            5e-4, 0.5, 0.999)

        ##############
        # Save model #
        ##############
        if stage == 0:
            check_list = open(os.path.join(args.log_dir, "checkpoint.txt"),
                              "a+")
        save_checkpoint(
            {
                'stage': stage + 1,
                'D_state_dict': discriminator.state_dict(),
                'G_state_dict': generator.state_dict(),
                'd_optimizer': d_opt.state_dict(),
                'g_optimizer': g_opt.state_dict(),
                'img_to_use': args.img_to_use
            }, check_list, args.log_dir, stage + 1)
        if stage == args.num_scale:
            check_list.close()
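# ---------------------------------------------------------------------------
# Worked example of the scale schedule computed by both workers above; the
# sizes 25 and 250 are illustrative values, not repo defaults.
def _scale_schedule_demo():
    img_size_min, img_size_max, scale_factor = 25, 250, 4 / 3
    num_scale = int(np.round(np.log(img_size_max / img_size_min)
                             / np.log(scale_factor)))
    size_list = [int(img_size_min * scale_factor**i)
                 for i in range(num_scale + 1)]
    print(num_scale)   # 8
    print(size_list)   # [25, 33, 44, 59, 79, 105, 140, 187, 249]
# ---------------------------------------------------------------------------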