def main(gpu, args):
    rank = args.nr * args.gpus + gpu

    if args.nodes > 1:
        dist.init_process_group("nccl", rank=rank, world_size=args.world_size)
        torch.cuda.set_device(gpu)

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    if args.dataset == "STL10":
        train_dataset = torchvision.datasets.STL10(
            args.dataset_dir,
            split="unlabeled",
            download=True,
            transform=TransformsSimCLR(size=args.image_size),
        )
    elif args.dataset == "CIFAR10":
        train_dataset = torchvision.datasets.CIFAR10(
            args.dataset_dir,
            download=True,
            transform=TransformsSimCLR(size=args.image_size),
        )
    else:
        raise NotImplementedError

    if args.nodes > 1:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset, num_replicas=args.world_size, rank=rank, shuffle=True
        )
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=(train_sampler is None),
        drop_last=True,
        num_workers=args.workers,
        sampler=train_sampler,
    )

    # initialize ResNet
    encoder = get_resnet(args.resnet, pretrained=False)
    n_features = encoder.fc.in_features  # get dimensions of fc layer

    # initialize model
    model = SimCLR(args, encoder, n_features)
    if args.reload:
        model_fp = os.path.join(
            args.model_path, "checkpoint_{}.tar".format(args.epoch_num)
        )
        model.load_state_dict(torch.load(model_fp, map_location=args.device.type))
    model = model.to(args.device)

    # optimizer / loss
    optimizer, scheduler = load_optimizer(args, model)
    criterion = NT_Xent(args.batch_size, args.temperature, args.device, args.world_size)

    # DDP / DP
    if args.dataparallel:
        model = convert_model(model)
        model = DataParallel(model)
    else:
        if args.nodes > 1:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            model = DDP(model, device_ids=[gpu])

    model = model.to(args.device)

    writer = None
    if args.nr == 0:
        writer = SummaryWriter()

    args.global_step = 0
    args.current_epoch = 0
    for epoch in range(args.start_epoch, args.epochs):
        lr = optimizer.param_groups[0]["lr"]
        loss_epoch = train(args, train_loader, model, criterion, optimizer, writer)

        if args.nr == 0 and scheduler:
            scheduler.step()

        if args.nr == 0 and epoch % 10 == 0:
            save_model(args, model, optimizer)

        if args.nr == 0:
            writer.add_scalar("Loss/train", loss_epoch / len(train_loader), epoch)
            writer.add_scalar("Misc/learning_rate", lr, epoch)
            print(
                f"Epoch [{epoch}/{args.epochs}]\t Loss: {loss_epoch / len(train_loader)}\t lr: {round(lr, 5)}"
            )
            args.current_epoch += 1

    ## end training
    save_model(args, model, optimizer)
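
# Launcher sketch (not part of the original file): main(gpu, args) above follows the usual
# one-process-per-GPU pattern, so a hypothetical entry point could look like the function
# below. The attribute names mirror the ones used in main(); for the multi-node case,
# MASTER_ADDR and MASTER_PORT are assumed to be set in the environment before
# dist.init_process_group is called.
def launch(args):
    import torch.multiprocessing as mp

    args.world_size = args.gpus * args.nodes
    if args.nodes > 1:
        # one worker process per local GPU; each process receives its GPU index as `gpu`
        mp.spawn(main, args=(args,), nprocs=args.gpus, join=True)
    else:
        # single-process run on GPU 0
        main(0, args)
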
# Modified version of main() with changes by @IvanKruzhilov: trains an auxiliary decoder
# alongside the SimCLR encoder and logs its loss and penalty terms per epoch.
def main(gpu, args):
    rank = args.nr * args.gpus + gpu

    if args.nodes > 1:
        dist.init_process_group("nccl", rank=rank, world_size=args.world_size)
        torch.cuda.set_device(gpu)

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    if args.dataset == "STL10":
        train_dataset = torchvision.datasets.STL10(
            args.dataset_dir,
            split="unlabeled",
            download=True,
            transform=TransformsSimCLR(size=args.image_size),
        )
    elif args.dataset == "CIFAR10":
        train_dataset = torchvision.datasets.CIFAR10(
            args.dataset_dir,
            download=True,
            transform=TransformsSimCLR(size=args.image_size),
        )
    else:
        raise NotImplementedError

    if args.nodes > 1:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset, num_replicas=args.world_size, rank=rank, shuffle=True
        )
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=(train_sampler is None),
        drop_last=True,
        num_workers=args.workers,
        sampler=train_sampler,
    )

    # initialize ResNet
    encoder = get_resnet(args.resnet, pretrained=False)
    n_features = encoder.fc.in_features  # get dimensions of fc layer

    # initialize model
    model = SimCLR(args, encoder, n_features)
    if args.reload:
        model_fp = os.path.join(
            args.model_path, "checkpoint_{}.tar".format(args.epoch_num)
        )
        print(model_fp)
        model.load_state_dict(torch.load(model_fp, map_location=args.device.type))
    model = model.to(args.device)

    # optimizer / loss
    optimizer, scheduler = load_optimizer(args, model)
    criterion = NT_Xent(args.batch_size, args.temperature, args.device, args.world_size)

    # DDP / DP
    if args.dataparallel:
        model = convert_model(model)
        model = DataParallel(model)
    else:
        if args.nodes > 1:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            model = DDP(model, device_ids=[gpu])

    model = model.to(args.device)

    writer = None
    if args.nr == 0:
        writer = SummaryWriter()

    # added by @IvanKruzhilov
    decoder = Decoder(3, 3, args.image_size)
    optimizer_decoder = torch.optim.Adam(decoder.parameters(), lr=0.001)
    # decoder.load_state_dict(torch.load('save/decoder_my_algorithm_augmented.pt'))
    decoder = decoder.to(args.device)

    args.global_step = 0
    args.current_epoch = 0
    for epoch in range(args.start_epoch, args.epochs):
        lr = optimizer.param_groups[0]["lr"]
        scatter_radius = 0.2
        random_fake = None  # set in train function now
        loss_epoch, loss_epoch_decoder, penalty_epoch = train(
            args,
            train_loader,
            model,
            decoder,
            criterion,
            optimizer,
            optimizer_decoder,
            writer,
            random_fake,
            scatter_radius,
        )
        loss_mean, bce_mean = train_autoencoder(
            model, decoder, train_loader, None, optimizer_decoder, freeze_encoder=True
        )

        if args.nr == 0 and scheduler:
            scheduler.step()

        if args.nr == 0 and epoch % 5 == 0:
            save_model(args, model, optimizer)
            torch.save(
                decoder.state_dict(),
                os.path.join(args.model_path, "decoder{0}.pt".format(epoch)),
            )

        if epoch % 10 == 0:
            # re-initialize the decoder and its optimizer every 10 epochs
            decoder = Decoder(3, 3, args.image_size)
            optimizer_decoder = torch.optim.Adam(decoder.parameters(), lr=0.001)
            decoder = decoder.to(args.device)

        if args.nr == 0:
            writer.add_scalar("Loss/train", loss_epoch / len(train_loader), epoch)
            writer.add_scalar("Misc/learning_rate", lr, epoch)
            mean_loss = loss_epoch / len(train_loader)
            mean_loss_decoder = loss_epoch_decoder / len(train_loader)
            mean_penalty = penalty_epoch / len(train_loader)
            print(
                f"Epoch [{epoch}/{args.epochs}]\t Loss: {mean_loss}\t "
                f"decoder loss: {mean_loss_decoder}\t penalty: {mean_penalty}\t lr: {round(lr, 5)}"
            )
            print("loss: ", loss_mean, "mse: ", bce_mean)
            args.current_epoch += 1

    ## end training
    save_model(args, model, optimizer)
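
# Sketch (not part of the original file): reloading one of the decoder checkpoints written by
# the modified main() above. It assumes the same Decoder(3, 3, args.image_size) constructor
# and the 'decoder{epoch}.pt' naming scheme used there; adjust the epoch and paths to your setup.
def load_decoder_checkpoint(args, epoch):
    decoder = Decoder(3, 3, args.image_size)
    model_fp = os.path.join(args.model_path, "decoder{0}.pt".format(epoch))
    decoder.load_state_dict(torch.load(model_fp, map_location=args.device.type))
    return decoder.to(args.device).eval()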