def train_model(model, model_test, criterion, optimizer, scheduler, num_epochs=25):
    """Run the training loop over the zipped view dataloaders and return the model.

    Relies on module-level state defined elsewhere in the file: ``opt``,
    ``dataloaders``, ``dataset_sizes``, ``start_epoch``, ``use_gpu``, ``fp16``,
    ``version``, ``y_loss`` and ``y_err``.

    Args:
        model: Network being optimised. It is called with two inputs when
            ``opt.views == 2`` and with three (or four, when
            ``opt.extra_Google`` is set) when ``opt.views == 3``.
        model_test: Passed together with ``model`` to ``update_average`` after
            every optimiser step when ``opt.moving_avg < 1.0``.
        criterion: Loss applied to each (output, label) pair; the per-view
            losses are summed.
        optimizer: Optimiser stepped once per processed batch.
        scheduler: Learning-rate scheduler stepped once per epoch.
        num_epochs: Exclusive upper bound of the epoch counter, which starts
            at the module-level ``start_epoch``.

    Returns:
        The same ``model`` object that was passed in.
    """
    since = time.time()

    # Loss multiplier used during warm-up: starts at 0.1 and is raised a
    # little on every iteration until it reaches 1.0.
    warm_up = 0.1
    # Number of iterations over which the warm-up ramp is spread.
    warm_iteration = round(dataset_sizes['satellite'] / opt.batchsize) * opt.warm_epoch

    for epoch in range(start_epoch, num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Only a training phase is run; the 'val' branches below are never
        # reached with this phase list.
        for phase in ['train']:
            if phase == 'train':
                model.train(True)  # Set model to training mode
            else:
                model.train(False)  # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0.0
            running_corrects2 = 0.0
            running_corrects3 = 0.0

            # Iterate over the four loaders in lockstep; zip stops at the
            # shortest one.
            for data, data2, data3, data4 in zip(dataloaders['satellite'],
                                                 dataloaders['street'],
                                                 dataloaders['drone'],
                                                 dataloaders['google']):
                inputs, labels = data
                inputs2, labels2 = data2
                inputs3, labels3 = data3
                inputs4, labels4 = data4

                # Only the batch dimension is needed, so the input is not
                # required to be exactly 4-D.
                now_batch_size = inputs.shape[0]
                if now_batch_size < opt.batchsize:  # skip the last batch
                    continue

                if use_gpu:
                    inputs = Variable(inputs.cuda().detach())
                    inputs2 = Variable(inputs2.cuda().detach())
                    inputs3 = Variable(inputs3.cuda().detach())
                    labels = Variable(labels.cuda().detach())
                    labels2 = Variable(labels2.cuda().detach())
                    labels3 = Variable(labels3.cuda().detach())
                    if opt.extra_Google:
                        inputs4 = Variable(inputs4.cuda().detach())
                        labels4 = Variable(labels4.cuda().detach())
                else:
                    inputs, labels = Variable(inputs), Variable(labels)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                if phase == 'val':
                    with torch.no_grad():
                        outputs, outputs2 = model(inputs, inputs2)
                else:
                    if opt.views == 2:
                        outputs, outputs2 = model(inputs, inputs2)
                    elif opt.views == 3:
                        if opt.extra_Google:
                            outputs, outputs2, outputs3, outputs4 = model(
                                inputs, inputs2, inputs3, inputs4)
                        else:
                            outputs, outputs2, outputs3 = model(
                                inputs, inputs2, inputs3)

                _, preds = torch.max(outputs.data, 1)
                _, preds2 = torch.max(outputs2.data, 1)

                if opt.views == 2:
                    loss = criterion(outputs, labels) + criterion(outputs2, labels2)
                elif opt.views == 3:
                    _, preds3 = torch.max(outputs3.data, 1)
                    loss = (criterion(outputs, labels)
                            + criterion(outputs2, labels2)
                            + criterion(outputs3, labels3))
                    if opt.extra_Google:
                        loss += criterion(outputs4, labels4)

                # Scale the loss down during the warm-up epochs.
                if epoch < opt.warm_epoch and phase == 'train':
                    warm_up = min(1.0, warm_up + 0.9 / warm_iteration)
                    loss *= warm_up

                # backward + optimize only if in training phase
                if phase == 'train':
                    if fp16:  # let amp scale the loss before backward
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()
                    optimizer.step()
                    if opt.moving_avg < 1.0:
                        update_average(model_test, model, opt.moving_avg)

                # statistics (the logged loss includes the warm-up scaling)
                if int(version[0]) > 0 or int(version[2]) > 3:
                    # for the new version like 0.4.0, 0.5.0 and 1.0.0
                    running_loss += loss.item() * now_batch_size
                else:
                    # for the old version like 0.3.0 and 0.3.1
                    running_loss += loss.data[0] * now_batch_size
                running_corrects += float(torch.sum(preds == labels.data))
                running_corrects2 += float(torch.sum(preds2 == labels2.data))
                if opt.views == 3:
                    running_corrects3 += float(torch.sum(preds3 == labels3.data))

            # All running sums are normalised by the satellite dataset size,
            # even though skipped batches contribute nothing to them.
            epoch_loss = running_loss / dataset_sizes['satellite']
            epoch_acc = running_corrects / dataset_sizes['satellite']
            epoch_acc2 = running_corrects2 / dataset_sizes['satellite']

            if opt.views == 2:
                print('{} Loss: {:.4f} Satellite_Acc: {:.4f}  Street_Acc: {:.4f}'
                      .format(phase, epoch_loss, epoch_acc, epoch_acc2))
            elif opt.views == 3:
                epoch_acc3 = running_corrects3 / dataset_sizes['satellite']
                print('{} Loss: {:.4f} Satellite_Acc: {:.4f}  Street_Acc: {:.4f} Drone_Acc: {:.4f}'
                      .format(phase, epoch_loss, epoch_acc, epoch_acc2, epoch_acc3))

            y_loss[phase].append(epoch_loss)
            y_err[phase].append(1.0 - epoch_acc)

            if phase == 'train':
                scheduler.step()
            # Checkpoint on epochs 19, 39, 59, ...
            if epoch % 20 == 19:
                save_network(model, opt.name, epoch)

        # Cumulative wall-clock time, reported after every epoch.
        time_elapsed = time.time() - since
        print('Training complete in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))

    return model
def update_log(self, results):
    """
    Fold the statistics of one batch into the running log ``self.log_dict``.

    The first call stores the batch statistics as they are; later calls
    merge them into the existing entries through ``update_average``,
    weighting by the stored sample count (or by the stored group counts for
    entries whose key starts with ``self.group_prefix``).

    Args:
        - results (dictionary): must contain every name in
          ``self.logged_fields`` as well as 'y_true'; 'y_pred' (and 'g' when
          group logging is enabled) are read when metrics are configured.
    """
    results = self.sanitize_dict(results, to_out_device=False)

    # every field we are asked to log has to be present
    for name in self.logged_fields:
        assert name in results, f"field {name} missing"

    skip_groups = self.no_group_logging
    prefix = self.group_prefix

    # metrics of the current batch
    batch_log = {}
    with torch.no_grad():
        for metric in self.logged_metrics:
            if not skip_groups:
                group_metrics, group_counts, _worst = metric.compute_group_wise(
                    results['y_pred'],
                    results['y_true'],
                    results['g'],
                    self.grouper.n_groups,
                    return_dict=False)
                batch_log[f'{prefix}{metric.name}'] = group_metrics
            overall = metric.compute(
                results['y_pred'],
                results['y_true'],
                return_dict=False)
            batch_log[metric.agg_metric_field] = overall.item()
        count = results['y_true'].numel()

    # copy the remaining logged fields over from the results dictionary
    for name in self.logged_fields:
        if name.startswith(prefix) and skip_groups:
            continue
        value = results[name]
        is_tensor = isinstance(value, torch.Tensor)
        if is_tensor and value.numel() == 1:
            # single-number statistic: store it as a Python scalar
            batch_log[name] = value.item()
            continue
        if is_tensor:
            assert value.numel() == self.grouper.n_groups, "Current implementation deals only with group-wise statistics or a single-number statistic"
            assert name.startswith(prefix)
        batch_log[name] = value

    if self._has_log:
        # running average across batches
        for key, value in batch_log.items():
            if not key.startswith(prefix):
                self.log_dict[key] = update_average(
                    self.log_dict[key],
                    self.log_dict[self.count_field],
                    value,
                    count)
            elif not skip_groups:
                self.log_dict[key] = update_average(
                    self.log_dict[key],
                    self.log_dict[self.group_count_field],
                    value,
                    group_counts)
        if not skip_groups:
            self.log_dict[self.group_count_field] += group_counts
        self.log_dict[self.count_field] += count
    else:
        # first entry: the batch log becomes the log
        self.log_dict = batch_log
        if not skip_groups:
            self.log_dict[self.group_count_field] = group_counts
        self.log_dict[self.count_field] = count

    self._has_log = True
def train_model(model, model_test, criterion, optimizer, scheduler, num_epochs=25):
    """Run the training loop with optional metric-learning losses and return the model.

    Relies on module-level state defined elsewhere in the file: ``opt``,
    ``dataloaders``, ``dataset_sizes``, ``start_epoch``, ``use_gpu``, ``fp16``,
    ``version``, ``y_loss`` and ``y_err``.

    When any of ``opt.arcface``, ``opt.cosface``, ``opt.circle``,
    ``opt.triplet``, ``opt.contrast``, ``opt.lifted`` or ``opt.sphere`` is set
    and ``opt.views == 3``, each model output is unpacked as a
    ``(logits, feature)`` pair. The features are L2-normalised and the enabled
    metric losses are added to the classification loss, either per view or,
    with ``opt.loss_merge``, on the concatenation of all views.

    Args:
        model: Network being optimised. It is called with two inputs when
            ``opt.views == 2`` and with three (or four, when
            ``opt.extra_Google`` is set) when ``opt.views == 3``.
        model_test: Passed together with ``model`` to ``update_average`` after
            every optimiser step when ``opt.moving_avg < 1.0``.
        criterion: Classification loss applied to each (logits, label) pair.
        optimizer: Optimiser stepped once per processed batch.
        scheduler: Learning-rate scheduler stepped once per epoch.
        num_epochs: Exclusive upper bound of the epoch counter, which starts
            at the module-level ``start_epoch``.

    Returns:
        The same ``model`` object that was passed in.
    """
    since = time.time()

    # Loss multiplier used during warm-up: starts at 0.1 and is raised a
    # little on every iteration until it reaches 1.0.
    warm_up = 0.1
    # Number of iterations over which the warm-up ramp is spread.
    warm_iteration = round(dataset_sizes['satellite'] / opt.batchsize) * opt.warm_epoch

    if opt.arcface:
        criterion_arcface = losses.ArcFaceLoss(num_classes=opt.nclasses, embedding_size=512)
    if opt.cosface:
        criterion_cosface = losses.CosFaceLoss(num_classes=opt.nclasses, embedding_size=512)
    if opt.circle:
        criterion_circle = CircleLoss(m=0.25, gamma=32)  # gamma = 64 may lead to a better result.
    if opt.triplet:
        miner = miners.MultiSimilarityMiner()
        criterion_triplet = losses.TripletMarginLoss(margin=0.3)
    if opt.lifted:
        criterion_lifted = losses.GeneralizedLiftedStructureLoss(neg_margin=1, pos_margin=0)
    if opt.contrast:
        criterion_contrast = losses.ContrastiveLoss(pos_margin=0, neg_margin=1)
    if opt.sphere:
        criterion_sphere = losses.SphereFaceLoss(num_classes=opt.nclasses, embedding_size=512, margin=4)

    # True when any metric-learning loss is enabled; the options do not change
    # during training, so this is evaluated once instead of once per batch.
    return_feature = (opt.arcface or opt.cosface or opt.circle or opt.triplet
                      or opt.contrast or opt.lifted or opt.sphere)

    for epoch in range(start_epoch, num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Only a training phase is run; the 'val' branches below are never
        # reached with this phase list.
        for phase in ['train']:
            if phase == 'train':
                model.train(True)  # Set model to training mode
            else:
                model.train(False)  # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0.0
            running_corrects2 = 0.0
            running_corrects3 = 0.0

            # Iterate over the four loaders in lockstep; zip stops at the
            # shortest one.
            for data, data2, data3, data4 in zip(dataloaders['satellite'],
                                                 dataloaders['street'],
                                                 dataloaders['drone'],
                                                 dataloaders['google']):
                inputs, labels = data
                inputs2, labels2 = data2
                inputs3, labels3 = data3
                inputs4, labels4 = data4

                # Only the batch dimension is needed, so the input is not
                # required to be exactly 4-D.
                now_batch_size = inputs.shape[0]
                if now_batch_size < opt.batchsize:  # skip the last batch
                    continue

                if use_gpu:
                    inputs = Variable(inputs.cuda().detach())
                    inputs2 = Variable(inputs2.cuda().detach())
                    inputs3 = Variable(inputs3.cuda().detach())
                    labels = Variable(labels.cuda().detach())
                    labels2 = Variable(labels2.cuda().detach())
                    labels3 = Variable(labels3.cuda().detach())
                    if opt.extra_Google:
                        inputs4 = Variable(inputs4.cuda().detach())
                        labels4 = Variable(labels4.cuda().detach())
                else:
                    inputs, labels = Variable(inputs), Variable(labels)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                if phase == 'val':
                    with torch.no_grad():
                        outputs, outputs2 = model(inputs, inputs2)
                else:
                    if opt.views == 2:
                        outputs, outputs2 = model(inputs, inputs2)
                    elif opt.views == 3:
                        if opt.extra_Google:
                            outputs, outputs2, outputs3, outputs4 = model(
                                inputs, inputs2, inputs3, inputs4)
                        else:
                            outputs, outputs2, outputs3 = model(inputs, inputs2, inputs3)

                if opt.views == 2:
                    _, preds = torch.max(outputs.data, 1)
                    _, preds2 = torch.max(outputs2.data, 1)
                    loss = criterion(outputs, labels) + criterion(outputs2, labels2)
                elif opt.views == 3:
                    if return_feature:
                        logits, ff = outputs
                        logits2, ff2 = outputs2
                        logits3, ff3 = outputs3
                        # L2-normalise each feature along dim 1.
                        fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
                        fnorm2 = torch.norm(ff2, p=2, dim=1, keepdim=True)
                        fnorm3 = torch.norm(ff3, p=2, dim=1, keepdim=True)
                        ff = ff.div(fnorm.expand_as(ff))
                        ff2 = ff2.div(fnorm2.expand_as(ff2))
                        ff3 = ff3.div(fnorm3.expand_as(ff3))
                        loss = (criterion(logits, labels)
                                + criterion(logits2, labels2)
                                + criterion(logits3, labels3))
                        _, preds = torch.max(logits.data, 1)
                        _, preds2 = torch.max(logits2.data, 1)
                        _, preds3 = torch.max(logits3.data, 1)
                        # Multiple perspectives are combined to calculate
                        # losses, please add ''--loss_merge'' in run.sh
                        if opt.loss_merge:
                            ff_all = torch.cat((ff, ff2, ff3), dim=0)
                            labels_all = torch.cat((labels, labels2, labels3), dim=0)
                        if opt.extra_Google:
                            logits4, ff4 = outputs4
                            fnorm4 = torch.norm(ff4, p=2, dim=1, keepdim=True)
                            ff4 = ff4.div(fnorm4.expand_as(ff4))
                            # Add only the fourth view's term instead of
                            # recomputing the three terms already in `loss`.
                            loss += criterion(logits4, labels4)
                            if opt.loss_merge:
                                ff_all = torch.cat((ff_all, ff4), dim=0)
                                labels_all = torch.cat((labels_all, labels4), dim=0)
                        if opt.arcface:
                            if opt.loss_merge:
                                loss += criterion_arcface(ff_all, labels_all)
                            else:
                                loss += (criterion_arcface(ff, labels)
                                         + criterion_arcface(ff2, labels2)
                                         + criterion_arcface(ff3, labels3))
                                if opt.extra_Google:
                                    loss += criterion_arcface(ff4, labels4)
                        if opt.cosface:
                            if opt.loss_merge:
                                loss += criterion_cosface(ff_all, labels_all)
                            else:
                                loss += (criterion_cosface(ff, labels)
                                         + criterion_cosface(ff2, labels2)
                                         + criterion_cosface(ff3, labels3))
                                if opt.extra_Google:
                                    loss += criterion_cosface(ff4, labels4)
                        if opt.circle:
                            if opt.loss_merge:
                                loss += criterion_circle(
                                    *convert_label_to_similarity(ff_all, labels_all)) / now_batch_size
                            else:
                                loss += (
                                    criterion_circle(*convert_label_to_similarity(ff, labels)) / now_batch_size
                                    + criterion_circle(*convert_label_to_similarity(ff2, labels2)) / now_batch_size
                                    + criterion_circle(*convert_label_to_similarity(ff3, labels3)) / now_batch_size)
                                if opt.extra_Google:
                                    loss += criterion_circle(
                                        *convert_label_to_similarity(ff4, labels4)) / now_batch_size
                        if opt.triplet:
                            if opt.loss_merge:
                                hard_pairs_all = miner(ff_all, labels_all)
                                loss += criterion_triplet(ff_all, labels_all, hard_pairs_all)
                            else:
                                hard_pairs = miner(ff, labels)
                                hard_pairs2 = miner(ff2, labels2)
                                hard_pairs3 = miner(ff3, labels3)
                                loss += (criterion_triplet(ff, labels, hard_pairs)
                                         + criterion_triplet(ff2, labels2, hard_pairs2)
                                         + criterion_triplet(ff3, labels3, hard_pairs3))
                                if opt.extra_Google:
                                    hard_pairs4 = miner(ff4, labels4)
                                    loss += criterion_triplet(ff4, labels4, hard_pairs4)
                        if opt.lifted:
                            if opt.loss_merge:
                                loss += criterion_lifted(ff_all, labels_all)
                            else:
                                loss += (criterion_lifted(ff, labels)
                                         + criterion_lifted(ff2, labels2)
                                         + criterion_lifted(ff3, labels3))
                                if opt.extra_Google:
                                    loss += criterion_lifted(ff4, labels4)
                        if opt.contrast:
                            if opt.loss_merge:
                                loss += criterion_contrast(ff_all, labels_all)
                            else:
                                loss += (criterion_contrast(ff, labels)
                                         + criterion_contrast(ff2, labels2)
                                         + criterion_contrast(ff3, labels3))
                                if opt.extra_Google:
                                    loss += criterion_contrast(ff4, labels4)
                        if opt.sphere:
                            if opt.loss_merge:
                                loss += criterion_sphere(ff_all, labels_all) / now_batch_size
                            else:
                                loss += (criterion_sphere(ff, labels) / now_batch_size
                                         + criterion_sphere(ff2, labels2) / now_batch_size
                                         + criterion_sphere(ff3, labels3) / now_batch_size)
                                if opt.extra_Google:
                                    # Scaled like the other three views; it was
                                    # previously added without the division.
                                    loss += criterion_sphere(ff4, labels4) / now_batch_size
                    else:
                        _, preds = torch.max(outputs.data, 1)
                        _, preds2 = torch.max(outputs2.data, 1)
                        _, preds3 = torch.max(outputs3.data, 1)
                        if opt.loss_merge:
                            outputs_all = torch.cat((outputs, outputs2, outputs3), dim=0)
                            labels_all = torch.cat((labels, labels2, labels3), dim=0)
                            if opt.extra_Google:
                                outputs_all = torch.cat((outputs_all, outputs4), dim=0)
                                labels_all = torch.cat((labels_all, labels4), dim=0)
                            # NOTE(review): the factor 4 is applied even when
                            # only three views are merged - confirm intended.
                            loss = 4 * criterion(outputs_all, labels_all)
                        else:
                            loss = (criterion(outputs, labels)
                                    + criterion(outputs2, labels2)
                                    + criterion(outputs3, labels3))
                            if opt.extra_Google:
                                loss += criterion(outputs4, labels4)

                # Scale the loss down during the warm-up epochs.
                if epoch < opt.warm_epoch and phase == 'train':
                    warm_up = min(1.0, warm_up + 0.9 / warm_iteration)
                    loss *= warm_up

                # backward + optimize only if in training phase
                if phase == 'train':
                    if fp16:  # let amp scale the loss before backward
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()
                    optimizer.step()
                    if opt.moving_avg < 1.0:
                        update_average(model_test, model, opt.moving_avg)

                # statistics (the logged loss includes the warm-up scaling)
                if int(version[0]) > 0 or int(version[2]) > 3:
                    # for the new version like 0.4.0, 0.5.0 and 1.0.0
                    running_loss += loss.item() * now_batch_size
                else:
                    # for the old version like 0.3.0 and 0.3.1
                    running_loss += loss.data[0] * now_batch_size
                running_corrects += float(torch.sum(preds == labels.data))
                running_corrects2 += float(torch.sum(preds2 == labels2.data))
                if opt.views == 3:
                    running_corrects3 += float(torch.sum(preds3 == labels3.data))

            # All running sums are normalised by the satellite dataset size,
            # even though skipped batches contribute nothing to them.
            epoch_loss = running_loss / dataset_sizes['satellite']
            epoch_acc = running_corrects / dataset_sizes['satellite']
            epoch_acc2 = running_corrects2 / dataset_sizes['satellite']

            if opt.views == 2:
                print('{} Loss: {:.4f} Satellite_Acc: {:.4f}  Street_Acc: {:.4f}'.format(
                    phase, epoch_loss, epoch_acc, epoch_acc2))
            elif opt.views == 3:
                epoch_acc3 = running_corrects3 / dataset_sizes['satellite']
                print('{} Loss: {:.4f} Satellite_Acc: {:.4f}  Street_Acc: {:.4f} Drone_Acc: {:.4f}'.format(
                    phase, epoch_loss, epoch_acc, epoch_acc2, epoch_acc3))

            y_loss[phase].append(epoch_loss)
            y_err[phase].append(1.0 - epoch_acc)

            if phase == 'train':
                scheduler.step()
            # Checkpoint on epochs 19, 39, 59, ...
            if epoch % 20 == 19:
                save_network(model, opt.name, epoch)

        # Cumulative wall-clock time, reported after every epoch.
        time_elapsed = time.time() - since
        print('Training complete in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))

    return model
def train(self, generator, discriminator, train_data_dir):
    """Train ``generator`` and ``discriminator`` one depth at a time.

    For each depth in ``range(self.depth)`` a dataloader is built with
    ``self.batch_sizes[depth]`` and run for ``self.num_epochs[depth]`` epochs.
    Every iteration performs one discriminator step followed by one generator
    step using the ``losses.WGAN_GP`` object, then updates a shadow copy of
    the generator via ``update_average`` with ``beta=0.999``. The fade-in
    factor ``alpha`` grows linearly from ``1 / fader_point`` to 1 over the
    first ``self.fade_ins[depth]`` percent of the depth's iterations.

    After every epoch the log is written to ``log.csv``, a sample grid is
    produced from a fixed latent batch, and, at the final depth, the generator
    weights are saved to the checkpoint directory.

    Args:
        generator: Generator network, called as
            ``generator(latents, depth, alpha)``.
        discriminator: Discriminator network handed to ``losses.WGAN_GP``.
        train_data_dir: Directory passed to ``ImageDataset``.
    """

    def sample_latents(batch_size):
        # Draw Gaussian latents and normalise each row (with a small epsilon
        # inside the sum for numerical stability).
        latents = torch.randn(batch_size, self.latent_dim).to(self.device)
        latents /= torch.sqrt(
            torch.sum(latents * latents + 1e-8, dim=1, keepdim=True))
        return latents

    # convert device
    generator.to(self.device)
    discriminator.to(self.device)

    # create a shadow copy of the generator and initialise its weights to the
    # generator's weights (beta=0)
    generator_shadow = copy.deepcopy(generator)
    update_average(generator_shadow, generator, beta=0)
    generator_shadow.train()

    optimizer_G = Adam(generator.parameters(), lr=self.learning_rate, betas=self.betas)
    optimizer_D = Adam(discriminator.parameters(), lr=self.learning_rate, betas=self.betas)

    image_size = 2 ** (self.depth + 1)

    print("Construct dataset")
    train_dataset = ImageDataset(train_data_dir, image_size)

    print("Construct optimizers")

    now = datetime.datetime.now()
    checkpoint_dir = os.path.join(self.checkpoint_root, f"{now.strftime('%Y%m%d_%H%M')}-progan")
    sample_dir = os.path.join(checkpoint_dir, "sample")
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(sample_dir, exist_ok=True)

    # training loop
    loss = losses.WGAN_GP(discriminator)
    print("Training Starts.")
    start_time = time.time()

    logger = Logger()
    iterations = 0

    # Fixed latents so the per-epoch samples are comparable.
    fixed_input = sample_latents(16)

    for current_depth in range(self.depth):
        print(f"Currently working on Depth: {current_depth}")
        current_res = np.power(2, current_depth + 2)
        print(f"Current resolution: {current_res} x {current_res}")

        train_dataloader = DataLoader(train_dataset,
                                      batch_size=self.batch_sizes[current_depth],
                                      shuffle=True,
                                      drop_last=True,
                                      num_workers=4)

        # Number of iterations at this depth over which alpha ramps to 1.
        total_batches = len(train_dataloader)
        fader_point = int((self.fade_ins[current_depth] / 100)
                          * self.num_epochs[current_depth] * total_batches)

        ticker = 1  # iteration counter within this depth, drives alpha

        for epoch in range(self.num_epochs[current_depth]):
            print(f"epoch {epoch}")

            for i, batch in enumerate(train_dataloader):
                alpha = ticker / fader_point if ticker <= fader_point else 1

                images = batch.to(self.device)
                real_samples_input = sample_latents(images.shape[0])

                real_samples = progressive_downsampling(images, current_depth, self.depth, alpha)

                # Discriminator step: fake samples are detached so no
                # gradient flows into the generator.
                fake_samples = generator(real_samples_input, current_depth, alpha).detach()
                d_loss = loss.dis_loss(real_samples, fake_samples, current_depth, alpha)

                optimizer_D.zero_grad()
                d_loss.backward()
                optimizer_D.step()
                d_loss_val = d_loss.item()

                # Generator step with freshly drawn latents.
                gan_input = sample_latents(images.shape[0])
                fake_samples = generator(gan_input, current_depth, alpha)
                g_loss = loss.gen_loss(real_samples, fake_samples, current_depth, alpha)

                optimizer_G.zero_grad()
                g_loss.backward()
                optimizer_G.step()

                update_average(generator_shadow, generator, beta=0.999)
                g_loss_val = g_loss.item()

                elapsed = time.time() - start_time
                logger.log(depth=current_depth,
                           epoch=epoch,
                           i=i,
                           iterations=iterations,
                           loss_G=g_loss_val,
                           loss_D=d_loss_val,
                           time=elapsed)

                ticker += 1
                iterations += 1

            logger.output_log(f"{checkpoint_dir}/log.csv")

            generator.eval()
            with torch.no_grad():
                sample_images = generator(fixed_input, current_depth, alpha)
            save_samples(sample_images, current_depth, sample_dir, current_depth, epoch, image_size)
            generator.train()

            # The last depth produced by range(self.depth) is self.depth - 1;
            # comparing against self.depth could never be true, so no
            # checkpoint was ever written.
            if current_depth == self.depth - 1:
                torch.save(generator.state_dict(),
                           f"{checkpoint_dir}/{current_depth}_{epoch}_generator.pth")