def train(args):
    """Train a FasterRCNN detector and save the resulting weights.

    Args:
        args: namespace (e.g. argparse/SageMaker) providing ``seed``,
            ``resize``, ``lr``, ``epochs``, ``log_interval``,
            ``model_directory`` and ``fn`` attributes.

    Side effects:
        Writes the trained model via ``save_model`` and emits progress
        to ``logger`` / stdout.
    """
    # Prefer GPU when available; seed both RNGs for reproducibility.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    torch.manual_seed(args.seed)
    if device.type == 'cuda':
        torch.cuda.manual_seed(args.seed)

    train_loader = _get_train_data_loader(args.resize)
    # NOTE(review): test_loader is built but never used in this function —
    # evaluation presumably happens elsewhere; confirm before removing.
    test_loader = _get_test_data_loader(args.resize)
    logger.debug("Processes {}/{} ({:.0f}%) of train data".format(
        len(train_loader.sampler), len(train_loader.dataset),
        100. * len(train_loader.sampler) / len(train_loader.dataset)))

    model = FasterRCNN().to(device)
    # Only optimize trainable parameters (frozen layers are skipped).
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=args.lr)

    for epoch in range(1, args.epochs + 1):
        print("Epoch: ", epoch)
        model.train()
        for batch_idx, (batch_images, batch_targets) in enumerate(train_loader):
            images = [img.to(device) for img in batch_images]
            targets = [{k: v.to(device) for k, v in t.items()}
                       for t in batch_targets]
            # torchvision detection models return a dict of named losses
            # in training mode; optimize their sum.
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
            # Zero grads on the optimizer so exactly the optimized params
            # are cleared (was model.zero_grad(); equivalent here but this
            # is the conventional form).
            optimizer.zero_grad()
            losses.backward()
            optimizer.step()
            if batch_idx % args.log_interval == 0:
                logger.info(
                    'Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f}'.format(
                        epoch, batch_idx * len(images), len(train_loader.sampler),
                        100. * batch_idx / len(train_loader), losses.item()))
    save_model(model, args.model_directory, args.fn)
# Ad-hoc training loop (notebook-style script); relies on model, device,
# data_loader and optimizer already being defined above.
num_epochs = 2
# range(1, num_epochs + 1): the previous range(1, num_epochs) ran only
# one epoch despite num_epochs = 2 (off-by-one).
for epoch in range(1, num_epochs + 1):
    print("Epoch: ", epoch)
    model.train()
    for idx, (batch_images, batch_targets) in enumerate(data_loader):
        images = [img.to(device) for img in batch_images]
        print(type(images), len(images))
        targets = [{k: v.to(device) for k, v in t.items()}
                   for t in batch_targets]
        print(type(targets), len(targets))
        # Detection model returns a dict of losses in training mode.
        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
        # 1-based iteration count (enumerate replaces the manual counter)
        # and the scalar loss value instead of the raw tensor repr.
        print("Iteration #: ", idx + 1, "Loss: ", losses.item())

# %%
# Save model with a filesystem-safe timestamp. The previous '%D' format
# expands to MM/DD/YY, whose slashes made the path invalid (the date
# components were treated as nested directories).
timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
torch.save(model.state_dict(), f"experiments/model_{timestamp}.pt")

# TODO: download the old model.pt and compare its output.