def predict(args, model, data_loader, distance, hall):
    """Run episodic few-shot inference and append one CSV row per episode.

    Each batch from ``data_loader`` is one episode: the first
    ``n_way * n_shot`` images are the support set, the remainder are queries.
    Class prototypes are augmented with ``n_aug`` hallucinated features per
    class, averaged, and queries are classified by the ``distance`` head.

    Args:
        args: namespace with n_way, n_shot, n_aug, dim, device, output_csv.
        model: feature-extractor backbone (used in eval mode).
        data_loader: yields (images, labels) episode batches.
        distance: callable mapping (prototypes, query features) -> logits.
        hall: hallucinator producing (n_aug, dim) fake features from one
            support feature vector.

    Side effects:
        Appends "episode_id,pred0,pred1,..." lines to args.output_csv.
    """
    with torch.no_grad():
        model.eval()
        hall.eval()
        for i, (data, target) in enumerate(data_loader):
            with open(args.output_csv, 'a') as f:
                f.write("{}".format(i))
            # Split the episode into support and query images.
            support_input = data[:args.n_way * args.n_shot].to(args.device)
            query_input = data[args.n_way * args.n_shot:].to(args.device)
            proto = model(support_input).view(args.n_way, args.n_shot, -1)
            # Augment each class prototype with hallucinated features.
            new_proto = torch.empty(
                [args.n_way, args.n_shot + args.n_aug, args.dim]).to(args.device)
            for c in range(args.n_way):
                fake = hall(proto[c][0])  # hallucinate from first support feature
                new_proto[c] = torch.cat([proto[c], fake], dim=0)
            new_proto = new_proto.mean(1)
            feature = model(query_input)
            # BUG FIX: use the caller-supplied distance head (which may carry
            # loaded parameters) instead of clobbering it with a freshly
            # initialized Distance(args) on every episode.
            logits = distance(new_proto, feature)
            preds = torch.argmax(logits, dim=1)
            with open(args.output_csv, 'a') as f:
                for pred in preds:
                    # BUG FIX: .item() writes a plain int; formatting the
                    # tensor directly would emit "tensor(k)" into the CSV.
                    f.write(",{}".format(pred.item()))
                f.write("\n")
def val(data_loader, model, hall, criterion, args):
    """Evaluate the episodic model on a validation loader.

    Args:
        data_loader: yields (image, label) episode batches; the first
            n_way*n_shot images are the support set, the rest are queries.
        model: feature-extractor backbone (evaluated in eval mode).
        hall: hallucinator producing n_aug fake features per class.
        criterion: loss on (logits, query labels).
        args: namespace with n_way, n_shot, n_query, n_aug, dim, device.

    Returns:
        (mean_loss, mean_accuracy) over all validation episodes.
    """
    with torch.no_grad():
        model.eval()
        # Consistency fix: predict() puts the hallucinator in eval mode too;
        # without this, train-mode layers (dropout/BN) would still be active.
        hall.eval()
        total_loss = []
        total_acc = []
        # Hoisted out of the loop: re-creating Distance(args) per episode
        # re-initializes it every time for no benefit.
        # NOTE(review): with distance_type == 'param' this head is freshly
        # initialized — presumably trained distance parameters should be
        # loaded by the caller; confirm against the training script.
        distance = Distance(args)
        for _, data in enumerate(data_loader):
            image, _ = data
            support = {'image': image[:args.n_way*args.n_shot].to(args.device),
                       'label': torch.LongTensor(
                           [i//args.n_shot for i in range(args.n_way*args.n_shot)])}
            query = {'image': image[args.n_way*args.n_shot:].to(args.device),
                     'label': torch.LongTensor(
                         [i//args.n_query for i in range(args.n_way*args.n_query)]
                     ).to(args.device)}
            # Prototypes from support embeddings: (n_way, n_shot, dim).
            proto = model(support['image']).view(args.n_way, args.n_shot, -1)
            # Augment each class with hallucinated features, then average.
            new_proto = torch.empty(
                [args.n_way, args.n_shot+args.n_aug, args.dim]).to(args.device)
            for c in range(args.n_way):
                fake = hall(proto[c][0])
                new_proto[c] = torch.cat([proto[c], fake], dim=0)
            new_proto = new_proto.mean(1)
            feature = model(query['image'])
            logits = distance(new_proto, feature)
            loss = criterion(logits, query['label'])
            total_loss.append(loss.item())
            accuracy = calculate_acc(logits, query['label'])
            total_acc.append(accuracy)
        print('Validation: loss {:.3f}, acc {:.3f}\n'.format(
            np.mean(total_loss), np.mean(total_acc)))
        return np.mean(total_loss), np.mean(total_acc)
def train(data_loader, model, hall, criterion, optimzer, args):
    """One epoch of episodic training for the backbone and hallucinator.

    Args:
        data_loader: yields (image, label) episode batches; the first
            n_way*n_shot images are the support set, the rest are queries.
        model: feature-extractor backbone (trained here).
        hall: hallucinator producing n_aug fake features per class prototype.
        criterion: classification loss on (logits, query labels).
        optimzer: dict with 'm' (model) and 'h' (hallucinator) optimizers.
            (Parameter name kept misspelled for caller compatibility.)
        args: namespace with n_way, n_shot, n_query, n_aug, dim, device.

    Side effects:
        Updates model/hall parameters; prints running and final loss/acc.
    """
    model.train()
    total_loss = []
    total_acc = []
    # Hoisted out of the loop: a per-step Distance(args) would be
    # re-initialized on every iteration.
    distance = Distance(args)
    for step, data in enumerate(data_loader):
        image, _ = data
        support = {'image': image[:args.n_way*args.n_shot].to(args.device),
                   'label': torch.LongTensor(
                       [i//args.n_shot for i in range(args.n_way*args.n_shot)])}
        query = {'image': image[args.n_way*args.n_shot:].to(args.device),
                 'label': torch.LongTensor(
                     [i//args.n_query for i in range(args.n_way*args.n_query)]
                 ).to(args.device)}
        # Prototypes from support embeddings: (n_way, n_shot, dim).
        proto = model(support['image']).view(args.n_way, args.n_shot, -1)
        # Augment each class with hallucinated features, then average.
        new_proto = torch.empty(
            [args.n_way, args.n_shot+args.n_aug, args.dim]).to(args.device)
        for c in range(args.n_way):
            fake = hall(proto[c][0])  # hallucinate from first support feature
            new_proto[c] = torch.cat([proto[c], fake], dim=0)
        new_proto = new_proto.mean(1)
        feature = model(query['image'])
        logits = distance(new_proto, feature)
        loss = criterion(logits, query['label'])
        total_loss.append(loss.item())
        accuracy = calculate_acc(logits, query['label'])  # typo 'accuarcy' fixed
        total_acc.append(accuracy)
        # Joint update of backbone and hallucinator.
        optimzer['m'].zero_grad()
        optimzer['h'].zero_grad()
        loss.backward()
        optimzer['m'].step()
        optimzer['h'].step()
        if step % 50 == 0:
            print('step {}: loss {:.3f}, acc {:.3f}'.format(
                step, np.mean(total_loss), np.mean(total_acc)), end='\r')
    print('Training: loss {:.3f}, acc {:.3f}'.format(
        np.mean(total_loss), np.mean(total_acc)), end='\n')
    # Tail of parse_args(); the rest of that definition is outside this chunk.
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()
    set_seed(123)
    # One batch = one full episode: n_way * (n_shot support + n_query query).
    test_dataset = MiniDataset(args.test_csv, args.test_data_dir)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.n_way * (args.n_query + args.n_shot),
                             num_workers=3, pin_memory=False,
                             worker_init_fn=worker_init_fn,
                             sampler=TestSampler(args.testcase_csv))
    # TODO: load your model
    model = Convnet4(out_channels=args.dim).to(args.device)
    model.load_state_dict(torch.load(args.model))
    distance = Distance(args)
    if args.distance_type == 'param':
        # A parametric distance head needs its trained weights loaded too.
        assert args.param
        distance.param.load_state_dict(torch.load(args.param))
    # Write the CSV header: episode_id plus one column per query (75 total,
    # presumably n_way * n_query — confirm against the sampler config).
    with open(args.output_csv, 'w') as f:
        f.write('episode_id')
        for i in range(75):
            f.write(',query{}'.format(i))
        f.write('\n')
    # NOTE(review): the predict() visible earlier in this file takes
    # (args, model, data_loader, distance, hall) — this call passes no
    # hallucinator. Confirm which predict() variant this main pairs with.
    predict(args, model, test_loader, distance)
def train(data_loader, model, hall, disc, criterion, optimzer, args):
    """One epoch of adversarial episodic training.

    Alternates three updates per step, sharing one forward pass through
    ``model``/``hall`` (hence the retain_graph flags below):
      1. classification loss -> model ('m') and hallucinator ('h'),
      2. critic loss -> discriminator ('d'), WGAN-style
         (-E[D(real)] + E[D(fake)] plus weight clipping),
      3. every 5th step, generator loss -E[D(fake)] -> hallucinator.

    Args:
        data_loader: yields (image, label) episode batches; first
            n_way*n_shot images are support, the rest queries.
        model: feature-extractor backbone.
        hall: hallucinator (generator) producing fake class features.
        disc: discriminator/critic; exposes weight_cliping()
            (project-defined, name sic — presumably clamps critic weights).
        criterion: classification loss on (logits, query labels).
        optimzer: dict of optimizers keyed 'm', 'h', 'd'.
        args: namespace with n_way, n_shot, n_query, n_aug, dim, device.
    """
    model.train()
    total_loss = []
    total_loss_d = []
    total_loss_g = []
    total_acc = []
    for step, data in enumerate(data_loader):
        image, _ = data
        support = {
            'image': image[:args.n_way * args.n_shot].to(args.device),
            'label': torch.LongTensor(
                [i // args.n_shot for i in range(args.n_way * args.n_shot)])
        }
        query = {
            'image': image[args.n_way * args.n_shot:].to(args.device),
            'label': torch.LongTensor([
                i // args.n_query for i in range(args.n_way * args.n_query)
            ]).to(args.device)
        }
        # Class prototypes from support embeddings: (n_way, n_shot, dim).
        proto = model(support['image']).view(args.n_way, args.n_shot, -1)
        # Train Model
        # Augment each class prototype with hallucinated features, then mean.
        new_proto = torch.empty(
            [args.n_way, args.n_shot + args.n_aug, args.dim]).to(args.device)
        for c in range(args.n_way):
            fake = hall(proto[c][0])  # hallucinate from first support feature
            new_proto[c] = torch.cat([proto[c], fake], dim=0)
        new_proto = new_proto.mean(1)
        feature = model(query['image'])
        # NOTE(review): a fresh Distance(args) per step re-initializes the
        # distance head each iteration — confirm this is intended.
        distance = Distance(args)
        logits = distance(new_proto, feature)
        loss = criterion(logits, query['label'])
        total_loss.append(loss.item())
        accuarcy = calculate_acc(logits, query['label'])
        total_acc.append(accuarcy)
        optimzer['m'].zero_grad()
        optimzer['h'].zero_grad()
        # retain_graph=True: the critic/generator losses below backprop
        # through the same proto graph again.
        loss.backward(retain_graph=True)
        optimzer['m'].step()
        optimzer['h'].step()
        # Train Discriminator
        # WGAN critic objective: maximize D(real) - D(fake),
        # i.e. minimize -loss_real + loss_fake.
        loss_real = 0
        loss_fake = 0
        for c in range(args.n_way):
            fake = hall(proto[c][0])
            loss_real += torch.mean(disc(proto[c]))
            loss_fake += torch.mean(disc(fake))
        loss_d = -loss_real + loss_fake
        total_loss_d.append(loss_d.item())
        optimzer['d'].zero_grad()
        # NOTE(review): this backward runs after optimizer steps above, so
        # gradients flow through a graph built from pre-update parameters —
        # confirm this ordering is intentional.
        loss_d.backward(retain_graph=True)
        optimzer['d'].step()
        disc.weight_cliping()
        # Train Generator
        # Update the hallucinator (generator) only every 5th step,
        # letting the critic train more often — typical WGAN scheduling.
        if step % 5 == 0:
            loss_g = 0
            for c in range(args.n_way):
                fake = hall(proto[c][0])
                loss_g += -torch.mean(disc(fake))
            total_loss_g.append(loss_g.item())
            optimzer['h'].zero_grad()
            loss_g.backward()
            optimzer['h'].step()
        if step % 50 == 0:
            print(
                'step {}: loss_D {:.3f}, loss_G {:.3f}, loss {:.3f}, acc {:.3f}'
                .format(step, np.mean(total_loss_d), np.mean(total_loss_g),
                        np.mean(total_loss), np.mean(total_acc)),
                end='\r')
    print('Training: loss_D {:.3f}, loss_G {:.3f}, loss {:.3f}, acc {:.3f}'.
          format(np.mean(total_loss_d), np.mean(total_loss_g),
                 np.mean(total_loss), np.mean(total_acc)),
          end='\n')