def train(epoch):
    """Run one WS-DAN training epoch (raw / attention-crop / attention-drop passes).

    Relies on module-level globals: ``model``, ``train_loader``, ``optimizer``,
    ``feature_center``, ``cross_entropy_loss``, ``center_loss``,
    ``batch_augment``, ``use_cuda``, ``BETA`` and ``args``.
    """
    # lr_scheduler.step()
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        # Skip batches containing out-of-range labels (valid range is [0, 20)).
        # FIX: the original test `target.numpy().any() >= 20 and
        # target.numpy().any() < 0` compared a bool to 20/0 and joined the
        # two impossible conditions with `and`, so it could never trigger.
        labels_np = target.numpy()
        if ((labels_np >= 20) | (labels_np < 0)).any():
            print(labels_np)
            continue
        if use_cuda:
            data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        y_pred_raw, feature_matrix, attention_map = model(data)

        # Update feature center: EMA towards the detached batch feature matrix.
        feature_center_batch = F.normalize(feature_center[target], dim=-1)
        feature_center[target] += BETA * (feature_matrix.detach() - feature_center_batch)

        ##################################
        # Attention Cropping
        ##################################
        with torch.no_grad():
            crop_images = batch_augment(data, attention_map[:, :1, :, :], mode='crop',
                                        theta=(0.4, 0.6), padding_ratio=0.1)
        # crop images forward
        y_pred_crop, _, _ = model(crop_images)

        ##################################
        # Attention Dropping
        ##################################
        with torch.no_grad():
            drop_images = batch_augment(data, attention_map[:, 1:, :, :], mode='drop',
                                        theta=(0.2, 0.5))
        # drop images forward
        y_pred_drop, _, _ = model(drop_images)

        # loss: equally-weighted CE over the three views plus the center loss
        batch_loss = cross_entropy_loss(y_pred_raw, target) / 3. + \
                     cross_entropy_loss(y_pred_crop, target) / 3. + \
                     cross_entropy_loss(y_pred_drop, target) / 3. + \
                     center_loss(feature_matrix, feature_center_batch)

        # backward
        batch_loss.backward()
        optimizer.step()

        if batch_idx % args.log_interval == 0:
            # FIX: `.item()` directly; `.data` is deprecated autograd API
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), batch_loss.item()))
def validate(**kwargs):
    """Run one validation pass and record the epoch loss/accuracy into `logs`.

    Required kwargs: 'logs' (dict, also read for 'train_info'), 'data_loader',
    'net', 'pbar'. Relies on module-level `loss_container`, `raw_metric`,
    `device`, `cross_entropy_loss`, `batch_augment`, `time` and `logging`.
    """
    # Retrieve training configuration
    logs = kwargs['logs']
    data_loader = kwargs['data_loader']
    net = kwargs['net']
    pbar = kwargs['pbar']

    # metrics initialization
    loss_container.reset()
    raw_metric.reset()

    # begin validation
    start_time = time.time()
    net.eval()
    with torch.no_grad():
        for i, (X, y, id) in enumerate(data_loader):
            # obtain data
            X = X.to(device)
            y = y.to(device)

            ##################################
            # Raw Image
            ##################################
            y_pred_raw, _, attention_map = net(X)

            ##################################
            # Object Localization and Refinement
            ##################################
            crop_images = batch_augment(X, attention_map, mode='crop',
                                        theta=0.1, padding_ratio=0.05)
            y_pred_crop, _, _ = net(crop_images)

            ##################################
            # Final prediction
            ##################################
            # average raw and crop-refined logits
            y_pred = (y_pred_raw + y_pred_crop) / 2.

            # loss — the container keeps a running average over the epoch
            batch_loss = cross_entropy_loss(y_pred, y)
            epoch_loss = loss_container(batch_loss.item())

            # metrics: top-1,5 error
            epoch_acc = raw_metric(y_pred, y)

    # end of validation — publish the final running values for this epoch
    logs['val_{}'.format(loss_container.name)] = epoch_loss
    logs['val_{}'.format(raw_metric.name)] = epoch_acc
    end_time = time.time()

    batch_info = 'Val Loss {:.4f}, Val Acc ({:.2f}, {:.2f})'.format(
        epoch_loss, epoch_acc[0], epoch_acc[1])
    pbar.set_postfix_str('{}, {}'.format(logs['train_info'], batch_info))

    # write log for this epoch
    logging.info('Valid: {}, Time {:3.2f}'.format(batch_info, end_time - start_time))
    logging.info('')
def write_csv(model, te_dataset, submission_df_path, options=None):
    """Write softmax class probabilities for `te_dataset` into the submission CSV.

    Parameters
    ----------
    model : trained classifier; when ``options.model == 4`` it is treated as
        WS-DAN (forward returns logits, features, attention maps) and the raw
        and attention-crop predictions are averaged.
    te_dataset : dataset yielding input tensors only (no labels).
    submission_df_path : path of the sample-submission CSV whose class
        columns are overwritten.
    options : namespace with ``model``, ``output_root``, ``output_name``.
        FIX: although it defaults to None, the output path below requires it;
        the original crashed with AttributeError only after the full
        prediction pass — now we fail fast with a clear message.

    Relies on module-level ``batch_size``, ``get_device``, ``DataLoader`` and
    ``batch_augment``.
    """
    print("Generating prediction...")
    if options is None:
        raise ValueError("write_csv requires `options` with output_root/output_name")
    device = get_device()
    te_dataloader = DataLoader(te_dataset, batch_size=batch_size, shuffle=False)
    submission_df = pd.read_csv(submission_df_path)
    test_pred = None
    model.eval()
    with torch.no_grad():
        for inputs in te_dataloader:
            inputs = inputs.to(device)
            if options.model == 4:
                # WS-DAN: average raw logits with the attention-crop refinement
                y_pred_raw, _, attention_map = model(inputs)
                crop_images = batch_augment(inputs, attention_map, mode='crop',
                                            theta=0.1, padding_ratio=0.05)
                y_pred_crop, _, _ = model(crop_images)
                outputs = (y_pred_raw + y_pred_crop) / 2.
            else:
                outputs = model(inputs)
            # accumulate logits on CPU (FIX: removed unused `preds` from torch.max)
            if test_pred is None:
                test_pred = outputs.data.cpu()
            else:
                test_pred = torch.cat((test_pred, outputs.data.cpu()), dim=0)
    # dtype=float maps to float64; spell it explicitly
    test_pred = torch.softmax(test_pred, dim=1, dtype=torch.float64)
    # FIX: hand pandas a numpy array rather than a torch tensor
    submission_df[['healthy', 'multiple_diseases', 'rust', 'scab']] = test_pred.numpy()
    submission_df.to_csv(options.output_root + options.output_name + '.csv', index=False)
def validation(model, val_dataloader, criterion, epoch, options=None):
    """Validate a WS-DAN model by averaging raw and attention-crop predictions.

    Returns (epoch_acc, epoch_loss) averaged over the whole validation set.
    Note: `criterion` and `options` are accepted but never used here — the
    loss comes from module-level `cross_entropy_loss`. Also relies on
    `get_device` and `batch_augment`.
    """
    device = get_device()
    model.to(device)
    model.eval()
    running_loss = 0.
    running_corrects = 0.
    with torch.no_grad():
        for inputs, labels, _ in val_dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            #labels = labels.squeeze(-1)
            # harmless in eval: no grads are produced under no_grad()
            model.zero_grad()
            y_pred_raw, _, attention_map = model(inputs)
            crop_images = batch_augment(inputs, attention_map, mode='crop',
                                        theta=0.1, padding_ratio=0.05)
            y_pred_crop, _, _ = model(crop_images)
            # final prediction: average of raw and crop-refined logits
            y_pred = (y_pred_raw + y_pred_crop) / 2.
            outputs = y_pred
            if len(labels.shape) == 0:
                # degenerate batch with a scalar label tensor: count a
                # zero loss instead of crashing in cross-entropy
                print("Error")
                loss = torch.tensor(0)
            else:
                loss = cross_entropy_loss(y_pred, labels)
            # weight the batch loss by batch size for a dataset-level average
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(labels == outputs.argmax(dim=1))
    epoch_loss = running_loss / len(val_dataloader.dataset)
    epoch_acc = running_corrects.double() / len(val_dataloader.dataset)
    print('[Validation]Epoch: {}, Loss: {:.4f} Acc: {:.4f}'.format(
        epoch, epoch_loss, epoch_acc))
    return epoch_acc, epoch_loss
def evaluate(test_loader, epoch, model):
    """Evaluate a WS-DAN binary classifier on `test_loader`.

    Returns (acc, precision, recall) in percent. tp/fp/tn/fn are computed
    with bitwise ops, so labels/predictions are assumed to be 0/1 integers.
    Relies on module-level `batch_augment` and `tqdm`.
    """
    model.eval()
    test_loss, correct, total, tp, fp, tn, fn = 0, 0, 0, 0, 0, 0, 0
    criterion = torch.nn.CrossEntropyLoss()
    # FIX: run under no_grad — the original built autograd graphs during eval
    with torch.no_grad():
        for data in tqdm(test_loader):
            try:
                image = data["image"].cuda()
                label = data["label"].cuda()
            except OSError:
                # corrupt/unreadable image: skip this batch
                continue
            y_pred_raw, _, attention_map = model(image)
            crop_images = batch_augment(image, attention_map, mode='crop',
                                        theta=0.1, padding_ratio=0.05)
            y_pred_crop, _, _ = model(crop_images)
            # final prediction: average of raw and crop-refined logits
            y_pred = (y_pred_raw + y_pred_crop) / 2.
            loss = criterion(y_pred, label)
            test_loss += loss.item()
            _, predict = y_pred.max(1)
            total += label.size(0)
            correct += predict.eq(label).sum().item()
            tp += torch.sum(predict & label)
            fp += torch.sum(predict & (1 - label))
            tn += torch.sum((1 - predict) & (1 - label))
            fn += torch.sum((1 - predict) & label)
    # FIX: guard the divisions — an empty loader or a degenerate confusion
    # matrix (tp+fp == 0 or tp+fn == 0) crashed the original with ZeroDivisionError
    acc = 100. * correct / total if total else 0.
    precision = 100.0 * tp / (tp + fp).float() if (tp + fp) else 0.
    recall = 100.0 * tp / (tp + fn).float() if (tp + fn) else 0.
    print(
        "==> [evaluate] epoch {}, loss = {}, acc = {}, precision = {}, recall = {}"
        .format(epoch, test_loss, acc, precision, recall))
    return acc, precision, recall
def validation(val_loader):
    """Validate the module-level `model` on `val_loader` and print the average
    loss and accuracy.

    Relies on module-level `model`, `use_cuda`, `batch_augment` and
    `cross_entropy_loss`.
    """
    model.eval()
    validation_loss = 0
    correct = 0
    # FIX: evaluate under no_grad — the original built autograd graphs
    with torch.no_grad():
        for data, target in val_loader:
            if use_cuda:
                data, target = data.cuda(), target.cuda()

            ##################################
            # Raw Image
            ##################################
            y_pred_raw, _, attention_map = model(data)

            ##################################
            # Object Localization and Refinement
            ##################################
            crop_images = batch_augment(data, attention_map, mode='crop',
                                        theta=0.1, padding_ratio=0.05)
            y_pred_crop, _, _ = model(crop_images)

            ##################################
            # Final prediction
            ##################################
            y_pred = (y_pred_raw + y_pred_crop) / 2.

            # FIX: accumulate the per-batch loss. The original initialized
            # `validation_loss` but never added to it, then divided only the
            # LAST batch's loss by the dataset size.
            batch_loss = cross_entropy_loss(y_pred, target)
            validation_loss += batch_loss.item()

            # get the index of the max log-probability
            pred = y_pred.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).cpu().sum()

    validation_loss /= len(val_loader.dataset)
    print('\nValidation set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        validation_loss, correct, len(val_loader.dataset),
        100. * correct / len(val_loader.dataset)))
def train(**kwargs):
    """Run one WS-DAN training epoch (raw / crop / drop passes + center loss)
    and record epoch metrics into `logs`.

    Required kwargs: 'logs', 'data_loader', 'net', 'feature_center',
    'optimizer', 'pbar'. Relies on module-level `loss_container`,
    `raw_metric`, `crop_metric`, `drop_metric`, `device`,
    `cross_entropy_loss`, `center_loss`, `batch_augment`, `F`, `time`
    and `logging`.
    """
    # Retrieve training configuration
    logs = kwargs['logs']
    data_loader = kwargs['data_loader']
    net = kwargs['net']
    feature_center = kwargs['feature_center']
    optimizer = kwargs['optimizer']
    pbar = kwargs['pbar']

    # metrics initialization
    loss_container.reset()
    raw_metric.reset()
    crop_metric.reset()
    drop_metric.reset()

    # begin training
    start_time = time.time()
    net.train()
    for i, (X, y) in enumerate(data_loader):
        optimizer.zero_grad()

        # obtain data for training
        X = X.to(device)
        y = y.to(device)

        ##################################
        # Raw Image
        ##################################
        # raw images forward
        y_pred_raw, feature_matrix, attention_map = net(X)

        # Update Feature Center: EMA towards the detached batch feature
        # matrix with rate 5e-2
        feature_center_batch = F.normalize(feature_center[y], dim=-1)
        feature_center[y] += 5e-2 * (feature_matrix.detach() - feature_center_batch)

        ##################################
        # Attention Cropping
        ##################################
        with torch.no_grad():
            crop_images = batch_augment(X, attention_map[:, :1, :, :], mode='crop',
                                        theta=(0.4, 0.6), padding_ratio=0.1)

        # crop images forward
        y_pred_crop, _, _ = net(crop_images)

        ##################################
        # Attention Dropping
        ##################################
        with torch.no_grad():
            drop_images = batch_augment(X, attention_map[:, 1:, :, :], mode='drop',
                                        theta=(0.2, 0.5))

        # drop images forward
        y_pred_drop, _, _ = net(drop_images)

        # loss: equally-weighted CE over the three views plus the center loss
        batch_loss = cross_entropy_loss(y_pred_raw, y) / 3. + \
                     cross_entropy_loss(y_pred_crop, y) / 3. + \
                     cross_entropy_loss(y_pred_drop, y) / 3. + \
                     center_loss(feature_matrix, feature_center_batch)

        # backward
        batch_loss.backward()
        optimizer.step()

        # metrics: loss and top-1,5 error (running averages over the epoch)
        with torch.no_grad():
            epoch_loss = loss_container(batch_loss.item())
            epoch_raw_acc = raw_metric(y_pred_raw, y)
            epoch_crop_acc = crop_metric(y_pred_crop, y)
            epoch_drop_acc = drop_metric(y_pred_drop, y)

        # end of this batch
        batch_info = 'Loss {:.4f}, Raw Acc ({:.2f}, {:.2f}), Crop Acc ({:.2f}, {:.2f}), Drop Acc ({:.2f}, {:.2f})'.format(
            epoch_loss, epoch_raw_acc[0], epoch_raw_acc[1],
            epoch_crop_acc[0], epoch_crop_acc[1],
            epoch_drop_acc[0], epoch_drop_acc[1])
        pbar.update()
        pbar.set_postfix_str(batch_info)

    # end of this epoch — publish the final running values
    logs['train_{}'.format(loss_container.name)] = epoch_loss
    logs['train_raw_{}'.format(raw_metric.name)] = epoch_raw_acc
    logs['train_crop_{}'.format(crop_metric.name)] = epoch_crop_acc
    logs['train_drop_{}'.format(drop_metric.name)] = epoch_drop_acc
    logs['train_info'] = batch_info
    end_time = time.time()

    # write log for this epoch
    logging.info('Train: {}, Time {:3.2f}'.format(batch_info, end_time - start_time))
def predict(image_path, model_param_path, save_path, img_save_name,
            resize=(224, 224), gen_hm=False):
    """Predict the 4-class leaf-disease distribution for one image with WS-DAN.

    Returns (y_pred, label): softmax probabilities of shape (1, 4), and the
    ground-truth row from ../data/train.csv matched by file name — or None
    when the image id is not found there (FIX: the original raised
    UnboundLocalError in that case). When `gen_hm` is set, raw / attention /
    heatmap images are written to `save_path`.

    Relies on module-level `WSDAN`, `batch_augment`, `generate_heatmap`,
    `STD`, `MEAN`, `ToPILImage`, `Image` and `transforms`.
    """
    image = Image.open(image_path).convert('RGB')
    transform = transforms.Compose([
        transforms.Resize(size=(int(resize[0]), int(resize[1]))),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    image = transform(image).unsqueeze(0)  # add batch dimension

    net = WSDAN(num_classes=4)
    net.load_state_dict(torch.load(model_param_path))
    net.eval()

    # Heuristic: checkpoints trained on GPU carry 'gpu' in the filename.
    if 'gpu' in model_param_path:
        print("please make sure your computer has a GPU")
        device = torch.device("cuda")
        try:
            net.to(device)
        except Exception:  # FIX: was a bare except
            print("No GPU in the environment")
    else:
        device = torch.device("cpu")

    X = image.to(device)

    # WS-DAN: raw pass, then a crop guided by the mean attention map
    y_pred_raw, _, attention_maps = net(X)
    attention_maps = torch.mean(attention_maps, dim=1, keepdim=True)

    # Augmentation with crop_mask
    crop_image = batch_augment(X, attention_maps, mode='crop',
                               theta=0.1, padding_ratio=0.05)
    y_pred_crop, _, _ = net(crop_image)
    y_pred = (y_pred_raw + y_pred_crop) / 2.
    # FIX: pass dim explicitly — implicit-dim softmax is deprecated/ambiguous
    y_pred = F.softmax(y_pred, dim=1)

    if gen_hm:
        attention_maps = F.upsample_bilinear(attention_maps,
                                             size=(X.size(2), X.size(3)))
        attention_maps = torch.sqrt(attention_maps.cpu() /
                                    attention_maps.max().item())

        # get heat attention maps
        heat_attention_maps = generate_heatmap(attention_maps)

        # raw_image, heat_attention, raw_attention
        raw_image = X.cpu() * STD + MEAN
        heat_attention_image = raw_image * 0.4 + heat_attention_maps * 0.6
        raw_attention_image = raw_image * attention_maps

        for batch_idx in range(X.size(0)):
            rimg = ToPILImage(raw_image[batch_idx])
            raimg = ToPILImage(raw_attention_image[batch_idx])
            haimg = ToPILImage(heat_attention_image[batch_idx])
            rimg.save(
                os.path.join(save_path, '{}_raw.jpg'.format(img_save_name)))
            raimg.save(
                os.path.join(save_path, '{}_raw_atten.jpg'.format(img_save_name)))
            haimg.save(
                os.path.join(save_path, '{}_heat_atten.jpg'.format(img_save_name)))

    # Ground-truth lookup: match the file name (without its 3-char extension)
    # against the image_id column.
    label = None  # FIX: original left `label` unbound when no row matched
    df = pd.read_csv("../data/train.csv")
    head, tail = os.path.split(image_path)
    image_id = tail[:-4]
    for i in range(len(df)):
        if df.loc[i, 'image_id'] == image_id:
            label = torch.tensor(
                df.loc[i, ['healthy', 'multiple_diseases', 'rust', 'scab']])
            break
    return y_pred, label
def main():
    """Evaluate a WS-DAN checkpoint (from config.eval_ckpt) on the CarDataset
    test split, writing predictions.csv and optional attention visualizations.

    Relies on module-level `config`, `device`, `savepath`, `visualize`,
    `mapping`, `CarDataset`, `WSDAN`, `TopKAccuracyMetric`, `batch_augment`,
    `generate_heatmap`, `STD`, `MEAN` and `ToPILImage`.
    """
    logging.basicConfig(
        format=
        '%(asctime)s: %(levelname)s: [%(filename)s:%(lineno)d]: %(message)s',
        level=logging.INFO)
    warnings.filterwarnings("ignore")

    try:
        ckpt = config.eval_ckpt
    except:
        logging.info('Set ckpt for evaluation in config.py')
        return

    ##################################
    # Dataset for testing
    ##################################
    # _, test_dataset = get_trainval_datasets(config.tag, resize=config.image_size)
    test_dataset = CarDataset('test')
    test_loader = DataLoader(test_dataset, batch_size=config.batch_size,
                             shuffle=False, num_workers=2, pin_memory=True)
    name2label, label2name = mapping('../training_labels.csv')

    ##################################
    # Initialize model
    ##################################
    net = WSDAN(num_classes=test_dataset.num_classes, M=config.num_attentions,
                net=config.net)

    # Load ckpt and get state_dict
    checkpoint = torch.load(ckpt)
    state_dict = checkpoint['state_dict']

    # Load weights
    net.load_state_dict(state_dict)
    logging.info('Network loaded from {}'.format(ckpt))

    ##################################
    # use cuda
    ##################################
    net.to(device)
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)

    ##################################
    # Prediction
    ##################################
    raw_accuracy = TopKAccuracyMetric(topk=(1, 5))
    ref_accuracy = TopKAccuracyMetric(topk=(1, 5))
    raw_accuracy.reset()
    ref_accuracy.reset()

    net.eval()
    logits = []
    ids = []
    with torch.no_grad():
        pbar = tqdm(total=len(test_loader), unit=' batches')
        pbar.set_description('Validation')
        for i, (X, y, id) in enumerate(test_loader):
            X = X.to(device)
            y = y.to(device)
            ids.extend(id)

            # WS-DAN
            y_pred_raw, _, attention_maps = net(X)

            # Augmentation with crop_mask
            crop_image = batch_augment(X, attention_maps, mode='crop',
                                       theta=0.1, padding_ratio=0.05)
            y_pred_crop, _, _ = net(crop_image)
            # average raw and crop-refined logits
            y_pred = (y_pred_raw + y_pred_crop) / 2.

            # Save the predictions.
            # NOTE(review): the submission is rebuilt from ALL logits so far
            # and rewritten on every batch; only the last iteration's file is
            # complete. Moving this below the loop would be cheaper — confirm
            # the incremental rewrite isn't relied upon for crash recovery.
            logits.append(y_pred.cpu())
            prediction = torch.argmax(torch.cat(logits, dim=0), dim=1)
            submission = pd.DataFrame(
                [ids, [label2name[x] for x in prediction.numpy()]]).transpose()
            submission.columns = ['id', 'label']
            submission.to_csv(savepath + 'predictions.csv', index=False)

            if visualize:
                # reshape attention maps to the input resolution
                attention_maps = F.upsample_bilinear(attention_maps,
                                                     size=(X.size(2), X.size(3)))
                attention_maps = torch.sqrt(attention_maps.cpu() /
                                            attention_maps.max().item())

                # get heat attention maps
                heat_attention_maps = generate_heatmap(attention_maps)

                # raw_image, heat_attention, raw_attention
                raw_image = X.cpu() * STD + MEAN
                heat_attention_image = raw_image * 0.5 + heat_attention_maps * 0.5
                raw_attention_image = raw_image * attention_maps

                for batch_idx in range(X.size(0)):
                    rimg = ToPILImage(raw_image[batch_idx])
                    raimg = ToPILImage(raw_attention_image[batch_idx])
                    haimg = ToPILImage(heat_attention_image[batch_idx])
                    rimg.save(
                        os.path.join(
                            savepath,
                            '%03d_raw.jpg' % (i * config.batch_size + batch_idx)))
                    raimg.save(
                        os.path.join(
                            savepath,
                            '%03d_raw_atten.jpg' % (i * config.batch_size + batch_idx)))
                    haimg.save(
                        os.path.join(
                            savepath,
                            '%03d_heat_atten.jpg' % (i * config.batch_size + batch_idx)))

            # Top K
            epoch_raw_acc = raw_accuracy(y_pred_raw, y)
            epoch_ref_acc = ref_accuracy(y_pred, y)

            # end of this batch
            batch_info = 'Val Acc: Raw ({:.2f}, {:.2f}), Refine ({:.2f}, {:.2f})'.format(
                epoch_raw_acc[0], epoch_raw_acc[1], epoch_ref_acc[0],
                epoch_ref_acc[1])
            pbar.update()
            pbar.set_postfix_str(batch_info)

        pbar.close()
def main():
    """Evaluate a WS-DAN checkpoint (inception_mixed_6e, M=32) on the
    CarDataset test split; optionally writes attention visualizations.

    Usage: python3 eval.py <model.ckpt>. Relies on module-level `device`,
    `savepath`, `visualize`, `CarDataset`, `WSDAN`, `TopKAccuracyMetric`,
    `batch_augment`, `generate_heatmap`, `STD`, `MEAN`, `ToPILImage`,
    `cudnn`, `nn` and `sys`.
    """
    logging.basicConfig(
        format=
        '%(asctime)s: %(levelname)s: [%(filename)s:%(lineno)d]: %(message)s',
        level=logging.INFO)
    warnings.filterwarnings("ignore")

    try:
        ckpt = sys.argv[1]
    except:
        logging.info('Usage: python3 eval.py <model.ckpt>')
        return

    ##################################
    # Dataset for testing
    ##################################
    test_dataset = CarDataset(phase='test', resize=448)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False,
                             num_workers=1, pin_memory=True)

    ##################################
    # Initialize model
    ##################################
    net = WSDAN(num_classes=test_dataset.num_classes, M=32,
                net='inception_mixed_6e')

    # Load ckpt and get state_dict
    checkpoint = torch.load(ckpt)
    state_dict = checkpoint['state_dict']

    # Load weights
    net.load_state_dict(state_dict)
    logging.info('Network loaded from {}'.format(ckpt))

    ##################################
    # use cuda
    ##################################
    cudnn.benchmark = True
    net.to(device)
    net = nn.DataParallel(net)
    net.eval()

    ##################################
    # Prediction
    ##################################
    accuracy = TopKAccuracyMetric(topk=(1, 5))
    accuracy.reset()

    with torch.no_grad():
        pbar = tqdm(total=len(test_loader), unit=' batches')
        pbar.set_description('Validation')
        for i, (X, y) in enumerate(test_loader):
            X = X.to(device)
            y = y.to(device)

            # WS-DAN
            y_pred_raw, feature_matrix, attention_maps = net(X)

            # Augmentation with crop_mask
            crop_image = batch_augment(X, attention_maps, mode='crop', theta=0.1)
            y_pred_crop, _, _ = net(crop_image)
            # average raw and crop-refined logits
            pred = (y_pred_raw + y_pred_crop) / 2.

            if visualize:
                # reshape attention maps to the input resolution
                attention_maps = F.upsample_bilinear(attention_maps,
                                                     size=(X.size(2), X.size(3)))
                attention_maps = torch.sqrt(attention_maps.cpu() /
                                            attention_maps.max().item())

                # get heat attention maps
                heat_attention_maps = generate_heatmap(attention_maps)

                # raw_image, heat_attention, raw_attention
                raw_image = X.cpu() * STD + MEAN
                heat_attention_image = raw_image * 0.5 + heat_attention_maps * 0.5
                raw_attention_image = raw_image * attention_maps

                # NOTE(review): file index is `i + batch_idx`, which collides
                # across batches unless batch_size is 1 (it is 1 above)
                for batch_idx in range(X.size(0)):
                    rimg = ToPILImage(raw_image[batch_idx])
                    raimg = ToPILImage(raw_attention_image[batch_idx])
                    haimg = ToPILImage(heat_attention_image[batch_idx])
                    rimg.save(
                        os.path.join(savepath, '%03d_raw.jpg' % (i + batch_idx)))
                    raimg.save(
                        os.path.join(savepath, '%03d_raw_atten.jpg' % (i + batch_idx)))
                    haimg.save(
                        os.path.join(savepath, '%03d_heat_atten.jpg' % (i + batch_idx)))

            # Top K (running averages over the whole pass)
            epoch_acc = accuracy(pred, y)

            # end of this batch
            batch_info = 'Val Acc ({:.2f}, {:.2f})'.format(
                epoch_acc[0], epoch_acc[1])
            pbar.update()
            pbar.set_postfix_str(batch_info)

        pbar.close()

    # show information for this epoch
    logging.info('Accuracy: %.2f, %.2f' % (epoch_acc[0], epoch_acc[1]))
def train(model, tr_dataloader, criterion, optimizer, epoch, options=None,
          feature_center=None):
    """One WS-DAN training epoch (raw / crop / drop passes + center loss).

    Returns (epoch_acc, epoch_loss) averaged over the dataset. Note:
    `criterion` and `options` are accepted but never used — the loss comes
    from module-level `cross_entropy_loss` plus `center_loss`. Also relies
    on `get_device`, `batch_augment`, `F` and `time`.
    """
    since = time.time()
    device = get_device()
    model.train()
    model.to(device)
    running_loss = 0.
    running_corrects = 0.
    for idx, (inputs, labels, _) in enumerate(tr_dataloader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        #labels = labels.squeeze(-1)
        optimizer.zero_grad()
        model.zero_grad()
        y_pred_raw, feature_matrix, attention_map = model(inputs)
        # if len(labels.shape

        # EMA update of the per-class feature centers (rate 0.05)
        feature_center_batch = F.normalize(feature_center[labels], dim=-1)
        feature_center[labels] += 0.05 * (feature_matrix.detach() - feature_center_batch)

        ##################################
        # Attention Cropping
        ##################################
        with torch.no_grad():
            crop_images = batch_augment(inputs, attention_map[:, :1, :, :],
                                        mode='crop', theta=(0.4, 0.6),
                                        padding_ratio=0.1)

        # crop images forward
        y_pred_crop, _, _ = model(crop_images)

        ##################################
        # Attention Dropping
        ##################################
        with torch.no_grad():
            drop_images = batch_augment(inputs, attention_map[:, 1:, :, :],
                                        mode='drop', theta=(0.2, 0.5))

        # drop images forward
        y_pred_drop, _, _ = model(drop_images)

        # combined prediction — used only for the accuracy metric below
        outputs = (y_pred_raw + y_pred_crop + y_pred_drop) / 3.

        # loss: equally-weighted CE over the three views plus the center loss
        loss = cross_entropy_loss(y_pred_raw, labels) / 3. + \
               cross_entropy_loss(y_pred_crop, labels) / 3. + \
               cross_entropy_loss(y_pred_drop, labels) / 3. + \
               center_loss(feature_matrix, feature_center_batch)
        loss.backward()
        optimizer.step()

        # batch-size-weighted loss for a dataset-level average
        batch_loss = loss.item() * inputs.size(0)
        batch_corrects = torch.sum(labels == outputs.argmax(dim=1))
        running_loss += batch_loss
        running_corrects += batch_corrects
        # periodic progress print (every 9 batches, offset 1)
        if idx % 9 == 1:
            print('[Train]Epoch: {}, idx: {}, Loss: {:.4f} Acc: {:.4f}'.format(
                epoch, idx, batch_loss / len(inputs),
                batch_corrects.float() / len(inputs)))

    epoch_loss = running_loss / len(tr_dataloader.dataset)
    epoch_acc = running_corrects.double() / len(tr_dataloader.dataset)
    print('[Train]Epoch: {}, Loss: {:.4f} Acc: {:.4f}'.format(
        epoch, epoch_loss, epoch_acc))
    return epoch_acc, epoch_loss
def main(result_arr):
    """Evaluate a WS-DAN checkpoint on the test split and append
    (class_index, class_name) pairs to `result_arr`; optionally writes
    attention visualizations.

    FIX: the original recording line `result.append(y_pred, d[y_pred)]` was a
    syntax error (mismatched bracket), referenced an undefined name `result`
    instead of the `result_arr` parameter, passed two arguments to append,
    and indexed the name dictionary with a logits tensor rather than the
    predicted class index. The CSV mapping was also re-read on every batch
    through an unclosed file handle; it is now built once up front.

    Relies on module-level `config`, `device`, `savepath`, `visualize`,
    `get_trainval_datasets`, `WSDAN`, `TopKAccuracyMetric`, `batch_augment`,
    `generate_heatmap`, `STD`, `MEAN` and `ToPILImage`.
    """
    logging.basicConfig(
        format='%(asctime)s: %(levelname)s: [%(filename)s:%(lineno)d]: %(message)s',
        level=logging.INFO)
    warnings.filterwarnings("ignore")

    try:
        ckpt = config.eval_ckpt
    except AttributeError:  # FIX: was a bare except
        logging.info('Set ckpt for evaluation in config.py')
        return

    ##################################
    # Dataset for testing
    ##################################
    _, test_dataset = get_trainval_datasets(config.tag, resize=config.image_size)
    test_loader = DataLoader(test_dataset, batch_size=config.batch_size,
                             shuffle=False, num_workers=2, pin_memory=True)

    ##################################
    # Initialize model
    ##################################
    net = WSDAN(num_classes=test_dataset.num_classes, M=config.num_attentions,
                net=config.net)

    # Load ckpt and get state_dict
    checkpoint = torch.load(ckpt)
    state_dict = checkpoint['state_dict']

    # Load weights
    net.load_state_dict(state_dict)
    logging.info('Network loaded from {}'.format(ckpt))

    ##################################
    # use cuda
    ##################################
    net.to(device)
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)

    ##################################
    # Prediction
    ##################################
    raw_accuracy = TopKAccuracyMetric(topk=(1, 5))
    ref_accuracy = TopKAccuracyMetric(topk=(1, 5))
    raw_accuracy.reset()
    ref_accuracy.reset()

    # Class-index -> name mapping, built ONCE (the original re-read the CSV
    # inside the batch loop with an unclosed file handle).
    # NOTE(review): rows are (k, v) stored as d[v] = k — confirm the second
    # column really holds stringified class indices.
    d = {}
    with open('/home/naman/Documents/Assignment_Job/out_dict.csv', 'r') as fh:
        for row in csv.reader(fh):
            k, v = row
            d[v] = k

    net.eval()
    with torch.no_grad():
        pbar = tqdm(total=len(test_loader), unit=' batches')
        pbar.set_description('Validation')
        for i, (X, y) in enumerate(test_loader):
            X = X.to(device)
            y = y.to(device)

            # WS-DAN
            y_pred_raw, _, attention_maps = net(X)

            # Augmentation with crop_mask
            crop_image = batch_augment(X, attention_maps, mode='crop',
                                       theta=0.1, padding_ratio=0.05)
            y_pred_crop, _, _ = net(crop_image)
            # average raw and crop-refined logits
            y_pred = (y_pred_raw + y_pred_crop) / 2.

            # FIX: record one (index, mapped name) pair per sample
            for cls_idx in torch.argmax(y_pred, dim=1).cpu().tolist():
                result_arr.append((cls_idx, d.get(str(cls_idx))))

            if visualize:
                # reshape attention maps to the input resolution
                attention_maps = F.upsample_bilinear(attention_maps,
                                                     size=(X.size(2), X.size(3)))
                attention_maps = torch.sqrt(attention_maps.cpu() /
                                            attention_maps.max().item())

                # get heat attention maps
                heat_attention_maps = generate_heatmap(attention_maps)

                # raw_image, heat_attention, raw_attention
                raw_image = X.cpu() * STD + MEAN
                heat_attention_image = raw_image * 0.5 + heat_attention_maps * 0.5
                raw_attention_image = raw_image * attention_maps

                for batch_idx in range(X.size(0)):
                    rimg = ToPILImage(raw_image[batch_idx])
                    raimg = ToPILImage(raw_attention_image[batch_idx])
                    haimg = ToPILImage(heat_attention_image[batch_idx])
                    rimg.save(os.path.join(
                        savepath, '%03d_raw.jpg' % (i * config.batch_size + batch_idx)))
                    raimg.save(os.path.join(
                        savepath, '%03d_raw_atten.jpg' % (i * config.batch_size + batch_idx)))
                    haimg.save(os.path.join(
                        savepath, '%03d_heat_atten.jpg' % (i * config.batch_size + batch_idx)))

            # Top K
            epoch_raw_acc = raw_accuracy(y_pred_raw, y)
            epoch_ref_acc = ref_accuracy(y_pred, y)

            # end of this batch
            batch_info = 'Val Acc: Raw ({:.2f}, {:.2f}), Refine ({:.2f}, {:.2f})'.format(
                epoch_raw_acc[0], epoch_raw_acc[1], epoch_ref_acc[0],
                epoch_ref_acc[1])
            pbar.update()
            pbar.set_postfix_str(batch_info)

        pbar.close()
def model_train(model, input, eval_loader, test_loader, cfg):
    """Train a WS-DAN-style model with raw + attention-crop losses, periodic
    checkpointing and evaluation, logging scalars to TensorBoard.

    FIX: the 200-batch progress message was a string literal broken across a
    physical line (a syntax error as written) and claimed "10 batches"
    although it is printed every 200 batches.

    Relies on module-level `load_model`, `CenterLoss`, `batch_augment`,
    `evaluate`, `train_loader`, `tqdm`, `SummaryWriter` and `time`
    (``from time import time`` style, since it is called as ``time()``).
    """
    model.train()
    writer = SummaryWriter(log_dir=cfg.log_path)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001,
                                 weight_decay=0.0005, amsgrad=True)
    current_epoch = 0
    global_step = 0
    if cfg.load_ckp == True:
        model, optimizer, current_epoch, global_step, loss = load_model(
            model, optimizer, cfg.ckp_path)

    # Currently unused: the feature-center update below is commented out.
    feature_center = torch.zeros(2, cfg.num_attentions * model.num_features)
    center_loss = CenterLoss()

    for epoch in range(current_epoch, cfg.NUM_EPOCHS, 1):
        running_loss = 0.0
        _time = time()
        for i, data in enumerate(tqdm(input)):
            # NOTE(review): the last batch is skipped deliberately (possibly
            # to avoid a ragged final batch) — confirm this is intended.
            if i == len(input) - 1:
                break
            try:
                image = data["image"].cuda()
                label = data["label"].cuda()
            except OSError:
                print("OSError of image. ")
                continue
            optimizer.zero_grad()
            y_pred_raw, feature_matrix, attention_map = model(image)
            '''
            # Update Feature Center
            feature_center_batch = torch.nn.functional.normalize(feature_center[label], dim=-1)
            print(feature_center[label].shape, feature_matrix.detach().shape, feature_center_batch.shape)
            feature_center_batch[label] += cfg.beta * (feature_matrix.detach() - feature_center_batch)
            '''
            # Attention Cropping
            with torch.no_grad():
                crop_images = batch_augment(image, attention_map[:, :1, :, :],
                                            mode='crop', theta=(0.4, 0.6),
                                            padding_ratio=0.1)

            # crop images forward
            y_pred_crop, _, _ = model(crop_images)
            '''
            # Attention Dropping
            with torch.no_grad():
                drop_images = batch_augment(image, attention_map[:, 1:, :, :], mode='drop', theta=(0.2, 0.5))
            # drop images forward
            y_pred_drop, _, _ = model(drop_images)
            '''
            # Only the raw and crop terms are active (drop + center loss are
            # commented out above), each weighted 1/3 as in WS-DAN.
            loss = criterion(y_pred_raw, label) / 3. + \
                   criterion(y_pred_crop, label) / 3
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            if i % 200 == 0:
                batch_time = time() - _time
                # FIX: single contiguous string; the interval is 200 batches
                print("==> [train] epoch {}, batch {}, global_step {}. "
                      "loss for 200 batches: {}, time for 200 batches: {}s"
                      .format(epoch, i, global_step, running_loss, batch_time))
                writer.add_scalar("scalar/loss", running_loss, global_step, time())
                running_loss = 0.0
                _time = time()
            global_step += 1

        # TODO add save condition eg. acc
        if epoch % cfg.evaluate_epoch == 0:
            torch.save(
                {
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "global_step": global_step,
                    'loss': loss,
                },
                os.path.join(cfg.save_path, "train_epoch_" + str(epoch) + ".tar"))
            print("==> [eval] on train dataset")
            # NOTE(review): this uses the module-level `train_loader`, not the
            # `input` loader passed in — confirm they are the same dataset.
            acc_on_train, precision_on_train, recall_on_train = evaluate(
                train_loader, epoch, model)
            print("==> [eval] on valid dataset")
            acc_on_valid, precision_on_valid, recall_on_valid = evaluate(
                eval_loader, epoch, model)
            writer.add_scalar("scalar/accuracy_on_train", acc_on_train,
                              global_step, time())
            writer.add_scalar("scalar/accuracy_on_valid", acc_on_valid,
                              global_step, time())
            # NOTE(review): tag below misspells "precision"; kept byte-identical
            # so existing TensorBoard histories stay continuous.
            writer.add_scalar("scalar/precisoin_on_train", precision_on_train,
                              global_step, time())
            writer.add_scalar("scalar/precision_on_valid", precision_on_valid,
                              global_step, time())
            writer.add_scalar("scalar/recall_on_train", recall_on_train,
                              global_step, time())
            writer.add_scalar("scalar/recall_on_valid", recall_on_valid,
                              global_step, time())
    writer.close()
    print("Finish training.")